From bd2bfbdad1ea1680eda6b4a397c354ec592d5630 Mon Sep 17 00:00:00 2001 From: jcking Date: Tue, 7 Jun 2022 18:55:35 +0000 Subject: [PATCH 001/303] Internal change PiperOrigin-RevId: 453491226 --- base/BUILD | 102 +++++- base/internal/type.post.h | 45 +-- base/internal/type.pre.h | 8 + base/type.cc | 164 +-------- base/type.h | 632 +---------------------------------- base/type_factory.cc | 1 - base/type_factory.h | 18 +- base/type_manager.h | 4 +- base/type_test.cc | 4 +- base/types/BUILD | 241 +++++++++++++ base/types/any_type.cc | 28 ++ base/types/any_type.h | 56 ++++ base/types/bool_type.cc | 28 ++ base/types/bool_type.h | 57 ++++ base/types/bytes_type.cc | 28 ++ base/types/bytes_type.h | 57 ++++ base/types/double_type.cc | 28 ++ base/types/double_type.h | 57 ++++ base/types/duration_type.cc | 28 ++ base/types/duration_type.h | 57 ++++ base/types/dyn_type.cc | 28 ++ base/types/dyn_type.h | 56 ++++ base/types/enum_type.cc | 41 +++ base/types/enum_type.h | 158 +++++++++ base/types/error_type.cc | 28 ++ base/types/error_type.h | 57 ++++ base/types/int_type.cc | 28 ++ base/types/int_type.h | 57 ++++ base/types/list_type.cc | 40 +++ base/types/list_type.h | 73 ++++ base/types/map_type.cc | 42 +++ base/types/map_type.h | 78 +++++ base/types/null_type.cc | 28 ++ base/types/null_type.h | 58 ++++ base/types/string_type.cc | 28 ++ base/types/string_type.h | 57 ++++ base/types/struct_type.cc | 43 +++ base/types/struct_type.h | 152 +++++++++ base/types/timestamp_type.cc | 28 ++ base/types/timestamp_type.h | 59 ++++ base/types/type_type.cc | 28 ++ base/types/type_type.h | 58 ++++ base/types/uint_type.cc | 28 ++ base/types/uint_type.h | 57 ++++ base/value.h | 17 + eval/public/structs/BUILD | 2 +- 46 files changed, 2119 insertions(+), 853 deletions(-) create mode 100644 base/types/BUILD create mode 100644 base/types/any_type.cc create mode 100644 base/types/any_type.h create mode 100644 base/types/bool_type.cc create mode 100644 base/types/bool_type.h create mode 100644 base/types/bytes_type.cc create mode 100644 base/types/bytes_type.h create mode 100644 base/types/double_type.cc create mode 100644 base/types/double_type.h create mode 100644 base/types/duration_type.cc create mode 100644 base/types/duration_type.h create mode 100644 base/types/dyn_type.cc create mode 100644 base/types/dyn_type.h create mode 100644 base/types/enum_type.cc create mode 100644 base/types/enum_type.h create mode 100644 base/types/error_type.cc create mode 100644 base/types/error_type.h create mode 100644 base/types/int_type.cc create mode 100644 base/types/int_type.h create mode 100644 base/types/list_type.cc create mode 100644 base/types/list_type.h create mode 100644 base/types/map_type.cc create mode 100644 base/types/map_type.h create mode 100644 base/types/null_type.cc create mode 100644 base/types/null_type.h create mode 100644 base/types/string_type.cc create mode 100644 base/types/string_type.h create mode 100644 base/types/struct_type.cc create mode 100644 base/types/struct_type.h create mode 100644 base/types/timestamp_type.cc create mode 100644 base/types/timestamp_type.h create mode 100644 base/types/type_type.cc create mode 100644 base/types/type_type.h create mode 100644 base/types/uint_type.cc create mode 100644 base/types/uint_type.h diff --git a/base/BUILD b/base/BUILD index 7a547dd68..052093a9a 100644 --- a/base/BUILD +++ b/base/BUILD @@ -103,35 +103,89 @@ cc_library( name = "type", srcs = [ "type.cc", - "type_factory.cc", - "type_manager.cc", - "type_provider.cc", ], hdrs = [ "type.h", - "type_factory.h", - "type_manager.h", - "type_provider.h", - "type_registry.h", ], deps = [ ":handle", ":kind", ":memory_manager", "//base/internal:type", - "//internal:casts", - "//internal:no_destructor", - "//internal:rtti", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:variant", + ], +) + +cc_library( + name = "type_manager", + srcs = ["type_manager.cc"], + hdrs = ["type_manager.h"], + deps = [ + ":type", + ":type_factory", + ":type_provider", "//internal:status_macros", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/hash", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/synchronization", + ], +) + +cc_library( + name = "type_provider", + srcs = ["type_provider.cc"], + hdrs = ["type_provider.h"], + deps = [ + ":handle", + ":type", + ":type_factory", + "//internal:no_destructor", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "type_registry", + hdrs = ["type_registry.h"], + deps = [":type_provider"], +) + +cc_library( + name = "type_factory", + srcs = ["type_factory.cc"], + hdrs = ["type_factory.h"], + deps = [ + ":handle", + ":memory_manager", + "//base/types:any", + "//base/types:bool", + "//base/types:bytes", + "//base/types:double", + "//base/types:duration", + "//base/types:dyn", + "//base/types:enum", + "//base/types:error", + "//base/types:int", + "//base/types:list", + "//base/types:map", + "//base/types:null", + "//base/types:string", + "//base/types:struct", + "//base/types:timestamp", + "//base/types:type", + "//base/types:uint", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/types:span", - "@com_google_absl//absl/types:variant", ], ) @@ -145,6 +199,8 @@ cc_test( ":handle", ":memory_manager", ":type", + ":type_factory", + ":type_manager", ":value", "//base/internal:memory_manager_testing", "//internal:testing", @@ -169,7 +225,25 @@ cc_library( ":kind", ":memory_manager", ":type", + ":type_manager", "//base/internal:value", + "//base/types:any", + "//base/types:bool", + "//base/types:bytes", + "//base/types:double", + "//base/types:duration", + "//base/types:dyn", + "//base/types:enum", + "//base/types:error", + "//base/types:int", + "//base/types:list", + "//base/types:map", + "//base/types:null", + "//base/types:string", + "//base/types:struct", + "//base/types:timestamp", + "//base/types:type", + "//base/types:uint", "//internal:casts", "//internal:no_destructor", "//internal:rtti", @@ -201,6 +275,8 @@ cc_test( deps = [ ":memory_manager", ":type", + ":type_factory", + ":type_manager", ":value", "//base/internal:memory_manager_testing", "//internal:strings", diff --git a/base/internal/type.post.h b/base/internal/type.post.h index ab220dc51..215493351 100644 --- a/base/internal/type.post.h +++ b/base/internal/type.post.h @@ -33,14 +33,6 @@ namespace cel { namespace base_internal { -inline internal::TypeInfo GetEnumTypeTypeId(const EnumType& enum_type) { - return enum_type.TypeId(); -} - -inline internal::TypeInfo GetStructTypeTypeId(const StructType& struct_type) { - return struct_type.TypeId(); -} - // Base implementation of persistent and transient handles for types. This // contains implementation details shared among both, but is never used // directly. The derived classes are responsible for defining appropriate @@ -85,12 +77,12 @@ class TypeHandleBase { // Called by `Transient` and `Persistent` to implement the same operator. friend bool operator==(const TypeHandleBase& lhs, const TypeHandleBase& rhs) { - const Type& lhs_type = ABSL_PREDICT_TRUE(static_cast(lhs)) - ? lhs.get() - : static_cast(NullType::Get()); - const Type& rhs_type = ABSL_PREDICT_TRUE(static_cast(rhs)) - ? rhs.get() - : static_cast(NullType::Get()); + if (static_cast(lhs) != static_cast(rhs) || + !static_cast(lhs)) { + return false; + } + const Type& lhs_type = lhs.get(); + const Type& rhs_type = rhs.get(); return lhs_type.Equals(rhs_type); } @@ -103,8 +95,6 @@ class TypeHandleBase { friend H AbslHashValue(H state, const TypeHandleBase& handle) { if (ABSL_PREDICT_TRUE(static_cast(handle))) { handle.get().HashValue(absl::HashState::Create(&state)); - } else { - NullType::Get().HashValue(absl::HashState::Create(&state)); } return state; } @@ -195,29 +185,6 @@ struct HandleTraits< } // namespace base_internal -#define CEL_INTERNAL_TYPE_DECL(name) \ - extern template class Persistent; \ - extern template class Persistent -CEL_INTERNAL_TYPE_DECL(Type); -CEL_INTERNAL_TYPE_DECL(NullType); -CEL_INTERNAL_TYPE_DECL(ErrorType); -CEL_INTERNAL_TYPE_DECL(DynType); -CEL_INTERNAL_TYPE_DECL(AnyType); -CEL_INTERNAL_TYPE_DECL(BoolType); -CEL_INTERNAL_TYPE_DECL(IntType); -CEL_INTERNAL_TYPE_DECL(UintType); -CEL_INTERNAL_TYPE_DECL(DoubleType); -CEL_INTERNAL_TYPE_DECL(BytesType); -CEL_INTERNAL_TYPE_DECL(StringType); -CEL_INTERNAL_TYPE_DECL(DurationType); -CEL_INTERNAL_TYPE_DECL(TimestampType); -CEL_INTERNAL_TYPE_DECL(EnumType); -CEL_INTERNAL_TYPE_DECL(StructType); -CEL_INTERNAL_TYPE_DECL(ListType); -CEL_INTERNAL_TYPE_DECL(MapType); -CEL_INTERNAL_TYPE_DECL(TypeType); -#undef CEL_INTERNAL_TYPE_DECL - } // namespace cel #endif // THIRD_PARTY_CEL_CPP_BASE_INTERNAL_TYPE_POST_H_ diff --git a/base/internal/type.pre.h b/base/internal/type.pre.h index 559ad1628..80ecacc54 100644 --- a/base/internal/type.pre.h +++ b/base/internal/type.pre.h @@ -55,6 +55,14 @@ internal::TypeInfo GetStructTypeTypeId(const StructType& struct_type); } // namespace cel +#define CEL_INTERNAL_TYPE_DECL(name) \ + extern template class Persistent; \ + extern template class Persistent + +#define CEL_INTERNAL_TYPE_IMPL(name) \ + template class Persistent; \ + template class Persistent + #define CEL_INTERNAL_DECLARE_TYPE(base, derived) \ private: \ friend class ::cel::base_internal::TypeHandleBase; \ diff --git a/base/type.cc b/base/type.cc index 9169a5d44..4d1bb6a6f 100644 --- a/base/type.cc +++ b/base/type.cc @@ -17,38 +17,11 @@ #include #include -#include "absl/strings/str_cat.h" -#include "absl/types/span.h" -#include "absl/types/variant.h" -#include "base/handle.h" -#include "base/type_manager.h" -#include "internal/casts.h" -#include "internal/no_destructor.h" +#include "absl/hash/hash.h" namespace cel { -#define CEL_INTERNAL_TYPE_IMPL(name) \ - template class Persistent; \ - template class Persistent CEL_INTERNAL_TYPE_IMPL(Type); -CEL_INTERNAL_TYPE_IMPL(NullType); -CEL_INTERNAL_TYPE_IMPL(ErrorType); -CEL_INTERNAL_TYPE_IMPL(DynType); -CEL_INTERNAL_TYPE_IMPL(AnyType); -CEL_INTERNAL_TYPE_IMPL(BoolType); -CEL_INTERNAL_TYPE_IMPL(IntType); -CEL_INTERNAL_TYPE_IMPL(UintType); -CEL_INTERNAL_TYPE_IMPL(DoubleType); -CEL_INTERNAL_TYPE_IMPL(BytesType); -CEL_INTERNAL_TYPE_IMPL(StringType); -CEL_INTERNAL_TYPE_IMPL(DurationType); -CEL_INTERNAL_TYPE_IMPL(TimestampType); -CEL_INTERNAL_TYPE_IMPL(EnumType); -CEL_INTERNAL_TYPE_IMPL(StructType); -CEL_INTERNAL_TYPE_IMPL(ListType); -CEL_INTERNAL_TYPE_IMPL(MapType); -CEL_INTERNAL_TYPE_IMPL(TypeType); -#undef CEL_INTERNAL_TYPE_IMPL std::string Type::DebugString() const { return std::string(name()); } @@ -65,139 +38,4 @@ void Type::HashValue(absl::HashState state) const { absl::HashState::combine(std::move(state), kind(), name()); } -const NullType& NullType::Get() { - static const internal::NoDestructor instance; - return *instance; -} - -const ErrorType& ErrorType::Get() { - static const internal::NoDestructor instance; - return *instance; -} - -const DynType& DynType::Get() { - static const internal::NoDestructor instance; - return *instance; -} - -const AnyType& AnyType::Get() { - static const internal::NoDestructor instance; - return *instance; -} - -const BoolType& BoolType::Get() { - static const internal::NoDestructor instance; - return *instance; -} - -const IntType& IntType::Get() { - static const internal::NoDestructor instance; - return *instance; -} - -const UintType& UintType::Get() { - static const internal::NoDestructor instance; - return *instance; -} - -const DoubleType& DoubleType::Get() { - static const internal::NoDestructor instance; - return *instance; -} - -const StringType& StringType::Get() { - static const internal::NoDestructor instance; - return *instance; -} - -const BytesType& BytesType::Get() { - static const internal::NoDestructor instance; - return *instance; -} - -const DurationType& DurationType::Get() { - static const internal::NoDestructor instance; - return *instance; -} - -const TimestampType& TimestampType::Get() { - static const internal::NoDestructor instance; - return *instance; -} - -struct EnumType::FindConstantVisitor final { - const EnumType& enum_type; - - absl::StatusOr operator()(absl::string_view name) const { - return enum_type.FindConstantByName(name); - } - - absl::StatusOr operator()(int64_t number) const { - return enum_type.FindConstantByNumber(number); - } -}; - -absl::StatusOr EnumType::FindConstant(ConstantId id) const { - return absl::visit(FindConstantVisitor{*this}, id.data_); -} - -struct StructType::FindFieldVisitor final { - const StructType& struct_type; - TypeManager& type_manager; - - absl::StatusOr operator()(absl::string_view name) const { - return struct_type.FindFieldByName(type_manager, name); - } - - absl::StatusOr operator()(int64_t number) const { - return struct_type.FindFieldByNumber(type_manager, number); - } -}; - -absl::StatusOr StructType::FindField( - TypeManager& type_manager, FieldId id) const { - return absl::visit(FindFieldVisitor{*this, type_manager}, id.data_); -} - -std::string ListType::DebugString() const { - return absl::StrCat(name(), "(", element()->DebugString(), ")"); -} - -bool ListType::Equals(const Type& other) const { - if (kind() != other.kind()) { - return false; - } - return element() == internal::down_cast(other).element(); -} - -void ListType::HashValue(absl::HashState state) const { - // We specifically hash the element first and then call the parent method to - // avoid hash suffix/prefix collisions. - Type::HashValue(absl::HashState::combine(std::move(state), element())); -} - -std::string MapType::DebugString() const { - return absl::StrCat(name(), "(", key()->DebugString(), ", ", - value()->DebugString(), ")"); -} - -bool MapType::Equals(const Type& other) const { - if (kind() != other.kind()) { - return false; - } - return key() == internal::down_cast(other).key() && - value() == internal::down_cast(other).value(); -} - -void MapType::HashValue(absl::HashState state) const { - // We specifically hash the element first and then call the parent method to - // avoid hash suffix/prefix collisions. - Type::HashValue(absl::HashState::combine(std::move(state), key(), value())); -} - -const TypeType& TypeType::Get() { - static const internal::NoDestructor instance; - return *instance; -} - } // namespace cel diff --git a/base/type.h b/base/type.h index 83d639cdb..f6d7d4f3c 100644 --- a/base/type.h +++ b/base/type.h @@ -18,19 +18,13 @@ #include #include -#include "absl/base/attributes.h" -#include "absl/base/macros.h" #include "absl/hash/hash.h" -#include "absl/status/statusor.h" #include "absl/strings/string_view.h" -#include "absl/types/span.h" #include "absl/types/variant.h" #include "base/handle.h" #include "base/internal/type.pre.h" // IWYU pragma: export #include "base/kind.h" #include "base/memory_manager.h" -#include "internal/casts.h" -#include "internal/rtti.h" namespace cel { @@ -133,609 +127,6 @@ class Type : public base_internal::Resource { using base_internal::Resource::Unref; }; -class NullType final : public Type { - public: - Kind kind() const override { return Kind::kNullType; } - - absl::string_view name() const override { return "null_type"; } - - // Note GCC does not consider a friend member as a member of a friend. - ABSL_ATTRIBUTE_PURE_FUNCTION static const NullType& Get(); - - private: - friend class NullValue; - friend class TypeFactory; - template - friend class internal::NoDestructor; - friend class base_internal::TypeHandleBase; - - // Called by base_internal::TypeHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Type& type) { return type.kind() == Kind::kNullType; } - - NullType() = default; - - NullType(const NullType&) = delete; - NullType(NullType&&) = delete; -}; - -class ErrorType final : public Type { - public: - Kind kind() const override { return Kind::kError; } - - absl::string_view name() const override { return "*error*"; } - - private: - friend class ErrorValue; - friend class TypeFactory; - template - friend class internal::NoDestructor; - friend class base_internal::TypeHandleBase; - - // Called by base_internal::TypeHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Type& type) { return type.kind() == Kind::kError; } - - ABSL_ATTRIBUTE_PURE_FUNCTION static const ErrorType& Get(); - - ErrorType() = default; - - ErrorType(const ErrorType&) = delete; - ErrorType(ErrorType&&) = delete; -}; - -class DynType final : public Type { - public: - Kind kind() const override { return Kind::kDyn; } - - absl::string_view name() const override { return "dyn"; } - - private: - friend class TypeFactory; - template - friend class internal::NoDestructor; - friend class base_internal::TypeHandleBase; - - // Called by base_internal::TypeHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Type& type) { return type.kind() == Kind::kDyn; } - - ABSL_ATTRIBUTE_PURE_FUNCTION static const DynType& Get(); - - DynType() = default; - - DynType(const DynType&) = delete; - DynType(DynType&&) = delete; -}; - -class AnyType final : public Type { - public: - Kind kind() const override { return Kind::kAny; } - - absl::string_view name() const override { return "google.protobuf.Any"; } - - private: - friend class TypeFactory; - template - friend class internal::NoDestructor; - friend class base_internal::TypeHandleBase; - - // Called by base_internal::TypeHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Type& type) { return type.kind() == Kind::kAny; } - - ABSL_ATTRIBUTE_PURE_FUNCTION static const AnyType& Get(); - - AnyType() = default; - - AnyType(const AnyType&) = delete; - AnyType(AnyType&&) = delete; -}; - -class BoolType final : public Type { - public: - Kind kind() const override { return Kind::kBool; } - - absl::string_view name() const override { return "bool"; } - - private: - friend class BoolValue; - friend class TypeFactory; - template - friend class internal::NoDestructor; - friend class base_internal::TypeHandleBase; - - // Called by base_internal::TypeHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Type& type) { return type.kind() == Kind::kBool; } - - ABSL_ATTRIBUTE_PURE_FUNCTION static const BoolType& Get(); - - BoolType() = default; - - BoolType(const BoolType&) = delete; - BoolType(BoolType&&) = delete; -}; - -class IntType final : public Type { - public: - Kind kind() const override { return Kind::kInt; } - - absl::string_view name() const override { return "int"; } - - private: - friend class IntValue; - friend class TypeFactory; - template - friend class internal::NoDestructor; - friend class base_internal::TypeHandleBase; - - // Called by base_internal::TypeHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Type& type) { return type.kind() == Kind::kInt; } - - ABSL_ATTRIBUTE_PURE_FUNCTION static const IntType& Get(); - - IntType() = default; - - IntType(const IntType&) = delete; - IntType(IntType&&) = delete; -}; - -class UintType final : public Type { - public: - Kind kind() const override { return Kind::kUint; } - - absl::string_view name() const override { return "uint"; } - - private: - friend class UintValue; - friend class TypeFactory; - template - friend class internal::NoDestructor; - friend class base_internal::TypeHandleBase; - - // Called by base_internal::TypeHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Type& type) { return type.kind() == Kind::kUint; } - - ABSL_ATTRIBUTE_PURE_FUNCTION static const UintType& Get(); - - UintType() = default; - - UintType(const UintType&) = delete; - UintType(UintType&&) = delete; -}; - -class DoubleType final : public Type { - public: - Kind kind() const override { return Kind::kDouble; } - - absl::string_view name() const override { return "double"; } - - private: - friend class DoubleValue; - friend class TypeFactory; - template - friend class internal::NoDestructor; - friend class base_internal::TypeHandleBase; - - // Called by base_internal::TypeHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Type& type) { return type.kind() == Kind::kDouble; } - - ABSL_ATTRIBUTE_PURE_FUNCTION static const DoubleType& Get(); - - DoubleType() = default; - - DoubleType(const DoubleType&) = delete; - DoubleType(DoubleType&&) = delete; -}; - -class StringType final : public Type { - public: - Kind kind() const override { return Kind::kString; } - - absl::string_view name() const override { return "string"; } - - private: - friend class StringValue; - friend class TypeFactory; - template - friend class internal::NoDestructor; - friend class base_internal::TypeHandleBase; - - // Called by base_internal::TypeHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Type& type) { return type.kind() == Kind::kString; } - - ABSL_ATTRIBUTE_PURE_FUNCTION static const StringType& Get(); - - StringType() = default; - - StringType(const StringType&) = delete; - StringType(StringType&&) = delete; -}; - -class BytesType final : public Type { - public: - Kind kind() const override { return Kind::kBytes; } - - absl::string_view name() const override { return "bytes"; } - - private: - friend class BytesValue; - friend class TypeFactory; - template - friend class internal::NoDestructor; - friend class base_internal::TypeHandleBase; - - // Called by base_internal::TypeHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Type& type) { return type.kind() == Kind::kBytes; } - - ABSL_ATTRIBUTE_PURE_FUNCTION static const BytesType& Get(); - - BytesType() = default; - - BytesType(const BytesType&) = delete; - BytesType(BytesType&&) = delete; -}; - -class DurationType final : public Type { - public: - Kind kind() const override { return Kind::kDuration; } - - absl::string_view name() const override { return "google.protobuf.Duration"; } - - private: - friend class DurationValue; - friend class TypeFactory; - template - friend class internal::NoDestructor; - friend class base_internal::TypeHandleBase; - - // Called by base_internal::TypeHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Type& type) { return type.kind() == Kind::kDuration; } - - ABSL_ATTRIBUTE_PURE_FUNCTION static const DurationType& Get(); - - DurationType() = default; - - DurationType(const DurationType&) = delete; - DurationType(DurationType&&) = delete; -}; - -class TimestampType final : public Type { - public: - Kind kind() const override { return Kind::kTimestamp; } - - absl::string_view name() const override { - return "google.protobuf.Timestamp"; - } - - private: - friend class TimestampValue; - friend class TypeFactory; - template - friend class internal::NoDestructor; - friend class base_internal::TypeHandleBase; - - // Called by base_internal::TypeHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Type& type) { return type.kind() == Kind::kTimestamp; } - - ABSL_ATTRIBUTE_PURE_FUNCTION static const TimestampType& Get(); - - TimestampType() = default; - - TimestampType(const TimestampType&) = delete; - TimestampType(TimestampType&&) = delete; -}; - -// EnumType represents an enumeration type. An enumeration is a set of constants -// that can be looked up by name and/or number. -class EnumType : public Type { - public: - struct Constant; - - class ConstantId final { - public: - explicit ConstantId(absl::string_view name) - : data_(absl::in_place_type, name) {} - - explicit ConstantId(int64_t number) - : data_(absl::in_place_type, number) {} - - ConstantId() = delete; - - ConstantId(const ConstantId&) = default; - ConstantId& operator=(const ConstantId&) = default; - - private: - friend class EnumType; - friend class EnumValue; - - absl::variant data_; - }; - - Kind kind() const final { return Kind::kEnum; } - - // Find the constant definition for the given identifier. - absl::StatusOr FindConstant(ConstantId id) const; - - protected: - EnumType() = default; - - // Construct a new instance of EnumValue with a type of this. Called by - // EnumValue::New. - virtual absl::StatusOr> NewInstanceByName( - TypedEnumValueFactory& factory, absl::string_view name) const = 0; - - // Construct a new instance of EnumValue with a type of this. Called by - // EnumValue::New. - virtual absl::StatusOr> NewInstanceByNumber( - TypedEnumValueFactory& factory, int64_t number) const = 0; - - // Called by FindConstant. - virtual absl::StatusOr FindConstantByName( - absl::string_view name) const = 0; - - // Called by FindConstant. - virtual absl::StatusOr FindConstantByNumber( - int64_t number) const = 0; - - private: - friend internal::TypeInfo base_internal::GetEnumTypeTypeId( - const EnumType& enum_type); - struct NewInstanceVisitor; - struct FindConstantVisitor; - - friend struct NewInstanceVisitor; - friend struct FindConstantVisitor; - friend class EnumValue; - friend class TypeFactory; - friend class base_internal::TypeHandleBase; - - // Called by base_internal::TypeHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Type& type) { return type.kind() == Kind::kEnum; } - - EnumType(const EnumType&) = delete; - EnumType(EnumType&&) = delete; - - std::pair SizeAndAlignment() const override = 0; - - // Called by CEL_IMPLEMENT_ENUM_TYPE() and Is() to perform type checking. - virtual internal::TypeInfo TypeId() const = 0; -}; - -// CEL_DECLARE_ENUM_TYPE declares `enum_type` as an enumeration type. It must be -// part of the class definition of `enum_type`. -// -// class MyEnumType : public cel::EnumType { -// ... -// private: -// CEL_DECLARE_ENUM_TYPE(MyEnumType); -// }; -#define CEL_DECLARE_ENUM_TYPE(enum_type) \ - CEL_INTERNAL_DECLARE_TYPE(Enum, enum_type) - -// CEL_IMPLEMENT_ENUM_TYPE implements `enum_type` as an enumeration type. It -// must be called after the class definition of `enum_type`. -// -// class MyEnumType : public cel::EnumType { -// ... -// private: -// CEL_DECLARE_ENUM_TYPE(MyEnumType); -// }; -// -// CEL_IMPLEMENT_ENUM_TYPE(MyEnumType); -#define CEL_IMPLEMENT_ENUM_TYPE(enum_type) \ - CEL_INTERNAL_IMPLEMENT_TYPE(Enum, enum_type) - -// StructType represents an struct type. An struct is a set of fields -// that can be looked up by name and/or number. -class StructType : public Type { - public: - struct Field; - - class FieldId final { - public: - explicit FieldId(absl::string_view name) - : data_(absl::in_place_type, name) {} - - explicit FieldId(int64_t number) - : data_(absl::in_place_type, number) {} - - FieldId() = delete; - - FieldId(const FieldId&) = default; - FieldId& operator=(const FieldId&) = default; - - private: - friend class StructType; - friend class StructValue; - - absl::variant data_; - }; - - Kind kind() const final { return Kind::kStruct; } - - // Find the field definition for the given identifier. - absl::StatusOr FindField(TypeManager& type_manager, FieldId id) const; - - protected: - StructType() = default; - - virtual absl::StatusOr> NewInstance( - TypedStructValueFactory& factory) const = 0; - - // Called by FindField. - virtual absl::StatusOr FindFieldByName( - TypeManager& type_manager, absl::string_view name) const = 0; - - // Called by FindField. - virtual absl::StatusOr FindFieldByNumber(TypeManager& type_manager, - int64_t number) const = 0; - - private: - friend internal::TypeInfo base_internal::GetStructTypeTypeId( - const StructType& struct_type); - struct FindFieldVisitor; - - friend struct FindFieldVisitor; - friend class TypeFactory; - friend class base_internal::TypeHandleBase; - friend class StructValue; - - // Called by base_internal::TypeHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Type& type) { return type.kind() == Kind::kStruct; } - - StructType(const StructType&) = delete; - StructType(StructType&&) = delete; - - std::pair SizeAndAlignment() const override = 0; - - // Called by CEL_IMPLEMENT_STRUCT_TYPE() and Is() to perform type checking. - virtual internal::TypeInfo TypeId() const = 0; -}; - -// CEL_DECLARE_STRUCT_TYPE declares `struct_type` as an struct type. It must be -// part of the class definition of `struct_type`. -// -// class MyStructType : public cel::StructType { -// ... -// private: -// CEL_DECLARE_STRUCT_TYPE(MyStructType); -// }; -#define CEL_DECLARE_STRUCT_TYPE(struct_type) \ - CEL_INTERNAL_DECLARE_TYPE(Struct, struct_type) - -// CEL_IMPLEMENT_ENUM_TYPE implements `struct_type` as an struct type. It -// must be called after the class definition of `struct_type`. -// -// class MyStructType : public cel::StructType { -// ... -// private: -// CEL_DECLARE_STRUCT_TYPE(MyStructType); -// }; -// -// CEL_IMPLEMENT_STRUCT_TYPE(MyStructType); -#define CEL_IMPLEMENT_STRUCT_TYPE(struct_type) \ - CEL_INTERNAL_IMPLEMENT_TYPE(Struct, struct_type) - -// ListType represents a list type. A list is a sequential container where each -// element is the same type. -class ListType : public Type { - // I would have liked to make this class final, but we cannot instantiate - // Persistent or Transient at this point. It must be - // done after the post include below. Maybe we should separate out the post - // includes on a per type basis so we can do that? - public: - Kind kind() const final { return Kind::kList; } - - absl::string_view name() const final { return "list"; } - - std::string DebugString() const final; - - // Returns the type of the elements in the list. - virtual Persistent element() const = 0; - - private: - friend class TypeFactory; - friend class base_internal::TypeHandleBase; - friend class base_internal::ListTypeImpl; - - // Called by base_internal::TypeHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Type& type) { return type.kind() == Kind::kList; } - - ListType() = default; - - ListType(const ListType&) = delete; - ListType(ListType&&) = delete; - - std::pair SizeAndAlignment() const override = 0; - - // Called by base_internal::TypeHandleBase. - bool Equals(const Type& other) const final; - - // Called by base_internal::TypeHandleBase. - void HashValue(absl::HashState state) const final; -}; - -// MapType represents a map type. A map is container of key and value pairs -// where each key appears at most once. -class MapType : public Type { - // I would have liked to make this class final, but we cannot instantiate - // Persistent or Transient at this point. It must be - // done after the post include below. Maybe we should separate out the post - // includes on a per type basis so we can do that? - public: - Kind kind() const final { return Kind::kMap; } - - absl::string_view name() const final { return "map"; } - - std::string DebugString() const final; - - // Returns the type of the keys in the map. - virtual Persistent key() const = 0; - - // Returns the type of the values in the map. - virtual Persistent value() const = 0; - - private: - friend class TypeFactory; - friend class base_internal::TypeHandleBase; - friend class base_internal::MapTypeImpl; - - // Called by base_internal::TypeHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Type& type) { return type.kind() == Kind::kMap; } - - MapType() = default; - - MapType(const MapType&) = delete; - MapType(MapType&&) = delete; - - std::pair SizeAndAlignment() const override = 0; - - // Called by base_internal::TypeHandleBase. - bool Equals(const Type& other) const final; - - // Called by base_internal::TypeHandleBase. - void HashValue(absl::HashState state) const final; -}; - -// TypeType represents the type of a type. -class TypeType final : public Type { - public: - Kind kind() const override { return Kind::kType; } - - absl::string_view name() const override { return "type"; } - - private: - friend class TypeValue; - friend class TypeFactory; - template - friend class internal::NoDestructor; - friend class base_internal::TypeHandleBase; - - // Called by base_internal::TypeHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Type& type) { return type.kind() == Kind::kType; } - - ABSL_ATTRIBUTE_PURE_FUNCTION static const TypeType& Get(); - - TypeType() = default; - - TypeType(const TypeType&) = delete; - TypeType(TypeType&&) = delete; -}; - } // namespace cel // type.pre.h forward declares types so they can be friended above. The types @@ -746,28 +137,7 @@ class TypeType final : public Type { namespace cel { -struct EnumType::Constant final { - explicit Constant(absl::string_view name, int64_t number) - : name(name), number(number) {} - - // The unqualified enumeration value name. - absl::string_view name; - // The enumeration value number. - int64_t number; -}; - -struct StructType::Field final { - explicit Field(absl::string_view name, int64_t number, - Persistent type) - : name(name), number(number), type(std::move(type)) {} - - // The field name. - absl::string_view name; - // The field number. - int64_t number; - // The field type; - Persistent type; -}; +CEL_INTERNAL_TYPE_DECL(Type); } // namespace cel diff --git a/base/type_factory.cc b/base/type_factory.cc index 2202782c3..5ad1ef789 100644 --- a/base/type_factory.cc +++ b/base/type_factory.cc @@ -20,7 +20,6 @@ #include "absl/status/status.h" #include "absl/synchronization/mutex.h" #include "base/handle.h" -#include "base/type.h" namespace cel { diff --git a/base/type_factory.h b/base/type_factory.h index af90fa990..02cc1674a 100644 --- a/base/type_factory.h +++ b/base/type_factory.h @@ -24,7 +24,23 @@ #include "absl/synchronization/mutex.h" #include "base/handle.h" #include "base/memory_manager.h" -#include "base/type.h" +#include "base/types/any_type.h" +#include "base/types/bool_type.h" +#include "base/types/bytes_type.h" +#include "base/types/double_type.h" +#include "base/types/duration_type.h" +#include "base/types/dyn_type.h" +#include "base/types/enum_type.h" +#include "base/types/error_type.h" +#include "base/types/int_type.h" +#include "base/types/list_type.h" +#include "base/types/map_type.h" +#include "base/types/null_type.h" +#include "base/types/string_type.h" +#include "base/types/struct_type.h" +#include "base/types/timestamp_type.h" +#include "base/types/type_type.h" +#include "base/types/uint_type.h" namespace cel { diff --git a/base/type_manager.h b/base/type_manager.h index bbeea1b3e..2d2af1bc1 100644 --- a/base/type_manager.h +++ b/base/type_manager.h @@ -20,11 +20,11 @@ #include "absl/base/attributes.h" #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" -#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" +#include "base/type.h" #include "base/type_factory.h" #include "base/type_provider.h" -#include "base/type_registry.h" namespace cel { diff --git a/base/type_test.cc b/base/type_test.cc index 2d0db3b15..50e0d122a 100644 --- a/base/type_test.cc +++ b/base/type_test.cc @@ -204,7 +204,7 @@ TEST_P(TypeTest, MoveConstructor) { Persistent from(type_factory.GetIntType()); Persistent to(std::move(from)); IS_INITIALIZED(from); - EXPECT_EQ(from, type_factory.GetNullType()); + EXPECT_FALSE(from); EXPECT_EQ(to, type_factory.GetIntType()); } @@ -221,7 +221,7 @@ TEST_P(TypeTest, MoveAssignment) { Persistent to(type_factory.GetNullType()); to = std::move(from); IS_INITIALIZED(from); - EXPECT_EQ(from, type_factory.GetNullType()); + EXPECT_FALSE(from); EXPECT_EQ(to, type_factory.GetIntType()); } diff --git a/base/types/BUILD b/base/types/BUILD new file mode 100644 index 000000000..c82f9af25 --- /dev/null +++ b/base/types/BUILD @@ -0,0 +1,241 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package( + # Under active development, not yet being released. + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) + +cc_library( + name = "any", + srcs = ["any_type.cc"], + hdrs = ["any_type.h"], + deps = [ + "//base:kind", + "//base:type", + "//internal:no_destructor", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "bool", + srcs = ["bool_type.cc"], + hdrs = ["bool_type.h"], + deps = [ + "//base:kind", + "//base:type", + "//internal:no_destructor", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "bytes", + srcs = ["bytes_type.cc"], + hdrs = ["bytes_type.h"], + deps = [ + "//base:kind", + "//base:type", + "//internal:no_destructor", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "double", + srcs = ["double_type.cc"], + hdrs = ["double_type.h"], + deps = [ + "//base:kind", + "//base:type", + "//internal:no_destructor", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "duration", + srcs = ["duration_type.cc"], + hdrs = ["duration_type.h"], + deps = [ + "//base:kind", + "//base:type", + "//internal:no_destructor", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "dyn", + srcs = ["dyn_type.cc"], + hdrs = ["dyn_type.h"], + deps = [ + "//base:kind", + "//base:type", + "//internal:no_destructor", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "enum", + srcs = ["enum_type.cc"], + hdrs = ["enum_type.h"], + deps = [ + "//base:kind", + "//base:type", + "//internal:rtti", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:variant", + ], +) + +cc_library( + name = "error", + srcs = ["error_type.cc"], + hdrs = ["error_type.h"], + deps = [ + "//base:kind", + "//base:type", + "//internal:no_destructor", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "int", + srcs = ["int_type.cc"], + hdrs = ["int_type.h"], + deps = [ + "//base:kind", + "//base:type", + "//internal:no_destructor", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "list", + srcs = ["list_type.cc"], + hdrs = ["list_type.h"], + deps = [ + "//base:kind", + "//base:type", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "map", + srcs = ["map_type.cc"], + hdrs = ["map_type.h"], + deps = [ + "//base:kind", + "//base:type", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "null", + srcs = ["null_type.cc"], + hdrs = ["null_type.h"], + deps = [ + "//base:kind", + "//base:type", + "//internal:no_destructor", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "string", + srcs = ["string_type.cc"], + hdrs = ["string_type.h"], + deps = [ + "//base:kind", + "//base:type", + "//internal:no_destructor", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "struct", + srcs = ["struct_type.cc"], + hdrs = ["struct_type.h"], + deps = [ + "//base:kind", + "//base:type", + "//internal:rtti", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:variant", + ], +) + +cc_library( + name = "timestamp", + srcs = ["timestamp_type.cc"], + hdrs = ["timestamp_type.h"], + deps = [ + "//base:kind", + "//base:type", + "//internal:no_destructor", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "type", + srcs = ["type_type.cc"], + hdrs = ["type_type.h"], + deps = [ + "//base:kind", + "//base:type", + "//internal:no_destructor", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "uint", + srcs = ["uint_type.cc"], + hdrs = ["uint_type.h"], + deps = [ + "//base:kind", + "//base:type", + "//internal:no_destructor", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/strings", + ], +) diff --git a/base/types/any_type.cc b/base/types/any_type.cc new file mode 100644 index 000000000..e1ba938da --- /dev/null +++ b/base/types/any_type.cc @@ -0,0 +1,28 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/types/any_type.h" + +#include "internal/no_destructor.h" + +namespace cel { + +CEL_INTERNAL_TYPE_IMPL(AnyType); + +const AnyType& AnyType::Get() { + static const internal::NoDestructor instance; + return *instance; +} + +} // namespace cel diff --git a/base/types/any_type.h b/base/types/any_type.h new file mode 100644 index 000000000..5ba77622e --- /dev/null +++ b/base/types/any_type.h @@ -0,0 +1,56 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_TYPES_ANY_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_BASE_TYPES_ANY_TYPE_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "base/kind.h" +#include "base/type.h" + +namespace cel { + +class AnyType final : public Type { + public: + Kind kind() const override { return Kind::kAny; } + + absl::string_view name() const override { return "google.protobuf.Any"; } + + private: + friend class TypeFactory; + template + friend class internal::NoDestructor; + friend class base_internal::TypeHandleBase; + + // Called by base_internal::TypeHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Type& type) { return type.kind() == Kind::kAny; } + + ABSL_ATTRIBUTE_PURE_FUNCTION static const AnyType& Get(); + + AnyType() = default; + + AnyType(const AnyType&) = delete; + AnyType(AnyType&&) = delete; +}; + +CEL_INTERNAL_TYPE_DECL(AnyType); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_TYPES_ANY_TYPE_H_ diff --git a/base/types/bool_type.cc b/base/types/bool_type.cc new file mode 100644 index 000000000..6b81034e7 --- /dev/null +++ b/base/types/bool_type.cc @@ -0,0 +1,28 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/types/bool_type.h" + +#include "internal/no_destructor.h" + +namespace cel { + +CEL_INTERNAL_TYPE_IMPL(BoolType); + +const BoolType& BoolType::Get() { + static const internal::NoDestructor instance; + return *instance; +} + +} // namespace cel diff --git a/base/types/bool_type.h b/base/types/bool_type.h new file mode 100644 index 000000000..cee2a43f6 --- /dev/null +++ b/base/types/bool_type.h @@ -0,0 +1,57 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_TYPES_BOOL_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_BASE_TYPES_BOOL_TYPE_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "base/kind.h" +#include "base/type.h" + +namespace cel { + +class BoolType final : public Type { + public: + Kind kind() const override { return Kind::kBool; } + + absl::string_view name() const override { return "bool"; } + + private: + friend class BoolValue; + friend class TypeFactory; + template + friend class internal::NoDestructor; + friend class base_internal::TypeHandleBase; + + // Called by base_internal::TypeHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Type& type) { return type.kind() == Kind::kBool; } + + ABSL_ATTRIBUTE_PURE_FUNCTION static const BoolType& Get(); + + BoolType() = default; + + BoolType(const BoolType&) = delete; + BoolType(BoolType&&) = delete; +}; + +CEL_INTERNAL_TYPE_DECL(BoolType); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_TYPES_BOOL_TYPE_H_ diff --git a/base/types/bytes_type.cc b/base/types/bytes_type.cc new file mode 100644 index 000000000..f174af975 --- /dev/null +++ b/base/types/bytes_type.cc @@ -0,0 +1,28 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/types/bytes_type.h" + +#include "internal/no_destructor.h" + +namespace cel { + +CEL_INTERNAL_TYPE_IMPL(BytesType); + +const BytesType& BytesType::Get() { + static const internal::NoDestructor instance; + return *instance; +} + +} // namespace cel diff --git a/base/types/bytes_type.h b/base/types/bytes_type.h new file mode 100644 index 000000000..412f684d7 --- /dev/null +++ b/base/types/bytes_type.h @@ -0,0 +1,57 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_TYPES_BYTES_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_BASE_TYPES_BYTES_TYPE_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "base/kind.h" +#include "base/type.h" + +namespace cel { + +class BytesType final : public Type { + public: + Kind kind() const override { return Kind::kBytes; } + + absl::string_view name() const override { return "bytes"; } + + private: + friend class BytesValue; + friend class TypeFactory; + template + friend class internal::NoDestructor; + friend class base_internal::TypeHandleBase; + + // Called by base_internal::TypeHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Type& type) { return type.kind() == Kind::kBytes; } + + ABSL_ATTRIBUTE_PURE_FUNCTION static const BytesType& Get(); + + BytesType() = default; + + BytesType(const BytesType&) = delete; + BytesType(BytesType&&) = delete; +}; + +CEL_INTERNAL_TYPE_DECL(BytesType); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_TYPES_BYTES_TYPE_H_ diff --git a/base/types/double_type.cc b/base/types/double_type.cc new file mode 100644 index 000000000..8735f51b4 --- /dev/null +++ b/base/types/double_type.cc @@ -0,0 +1,28 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/types/double_type.h" + +#include "internal/no_destructor.h" + +namespace cel { + +CEL_INTERNAL_TYPE_IMPL(DoubleType); + +const DoubleType& DoubleType::Get() { + static const internal::NoDestructor instance; + return *instance; +} + +} // namespace cel diff --git a/base/types/double_type.h b/base/types/double_type.h new file mode 100644 index 000000000..946cbe080 --- /dev/null +++ b/base/types/double_type.h @@ -0,0 +1,57 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_TYPES_DOUBLE_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_BASE_TYPES_DOUBLE_TYPE_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "base/kind.h" +#include "base/type.h" + +namespace cel { + +class DoubleType final : public Type { + public: + Kind kind() const override { return Kind::kDouble; } + + absl::string_view name() const override { return "double"; } + + private: + friend class DoubleValue; + friend class TypeFactory; + template + friend class internal::NoDestructor; + friend class base_internal::TypeHandleBase; + + // Called by base_internal::TypeHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Type& type) { return type.kind() == Kind::kDouble; } + + ABSL_ATTRIBUTE_PURE_FUNCTION static const DoubleType& Get(); + + DoubleType() = default; + + DoubleType(const DoubleType&) = delete; + DoubleType(DoubleType&&) = delete; +}; + +CEL_INTERNAL_TYPE_DECL(DoubleType); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_TYPES_DOUBLE_TYPE_H_ diff --git a/base/types/duration_type.cc b/base/types/duration_type.cc new file mode 100644 index 000000000..f7617b722 --- /dev/null +++ b/base/types/duration_type.cc @@ -0,0 +1,28 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/types/duration_type.h" + +#include "internal/no_destructor.h" + +namespace cel { + +CEL_INTERNAL_TYPE_IMPL(DurationType); + +const DurationType& DurationType::Get() { + static const internal::NoDestructor instance; + return *instance; +} + +} // namespace cel diff --git a/base/types/duration_type.h b/base/types/duration_type.h new file mode 100644 index 000000000..b6855751a --- /dev/null +++ b/base/types/duration_type.h @@ -0,0 +1,57 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_TYPES_DURATION_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_BASE_TYPES_DURATION_TYPE_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "base/kind.h" +#include "base/type.h" + +namespace cel { + +class DurationType final : public Type { + public: + Kind kind() const override { return Kind::kDuration; } + + absl::string_view name() const override { return "google.protobuf.Duration"; } + + private: + friend class DurationValue; + friend class TypeFactory; + template + friend class internal::NoDestructor; + friend class base_internal::TypeHandleBase; + + // Called by base_internal::TypeHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Type& type) { return type.kind() == Kind::kDuration; } + + ABSL_ATTRIBUTE_PURE_FUNCTION static const DurationType& Get(); + + DurationType() = default; + + DurationType(const DurationType&) = delete; + DurationType(DurationType&&) = delete; +}; + +CEL_INTERNAL_TYPE_DECL(DurationType); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_TYPES_DURATION_TYPE_H_ diff --git a/base/types/dyn_type.cc b/base/types/dyn_type.cc new file mode 100644 index 000000000..c9a12e3ae --- /dev/null +++ b/base/types/dyn_type.cc @@ -0,0 +1,28 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/types/dyn_type.h" + +#include "internal/no_destructor.h" + +namespace cel { + +CEL_INTERNAL_TYPE_IMPL(DynType); + +const DynType& DynType::Get() { + static const internal::NoDestructor instance; + return *instance; +} + +} // namespace cel diff --git a/base/types/dyn_type.h b/base/types/dyn_type.h new file mode 100644 index 000000000..7e5439ddd --- /dev/null +++ b/base/types/dyn_type.h @@ -0,0 +1,56 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_TYPES_DYN_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_BASE_TYPES_DYN_TYPE_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "base/kind.h" +#include "base/type.h" + +namespace cel { + +class DynType final : public Type { + public: + Kind kind() const override { return Kind::kDyn; } + + absl::string_view name() const override { return "dyn"; } + + private: + friend class TypeFactory; + template + friend class internal::NoDestructor; + friend class base_internal::TypeHandleBase; + + // Called by base_internal::TypeHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Type& type) { return type.kind() == Kind::kDyn; } + + ABSL_ATTRIBUTE_PURE_FUNCTION static const DynType& Get(); + + DynType() = default; + + DynType(const DynType&) = delete; + DynType(DynType&&) = delete; +}; + +CEL_INTERNAL_TYPE_DECL(DynType); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_TYPES_DYN_TYPE_H_ diff --git a/base/types/enum_type.cc b/base/types/enum_type.cc new file mode 100644 index 000000000..ad6dc0371 --- /dev/null +++ b/base/types/enum_type.cc @@ -0,0 +1,41 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/types/enum_type.h" + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/variant.h" + +namespace cel { + +CEL_INTERNAL_TYPE_IMPL(EnumType); + +struct EnumType::FindConstantVisitor final { + const EnumType& enum_type; + + absl::StatusOr operator()(absl::string_view name) const { + return enum_type.FindConstantByName(name); + } + + absl::StatusOr operator()(int64_t number) const { + return enum_type.FindConstantByNumber(number); + } +}; + +absl::StatusOr EnumType::FindConstant(ConstantId id) const { + return absl::visit(FindConstantVisitor{*this}, id.data_); +} + +} // namespace cel diff --git a/base/types/enum_type.h b/base/types/enum_type.h new file mode 100644 index 000000000..500318751 --- /dev/null +++ b/base/types/enum_type.h @@ -0,0 +1,158 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_TYPES_ENUM_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_BASE_TYPES_ENUM_TYPE_H_ + +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/variant.h" +#include "base/kind.h" +#include "base/type.h" +#include "internal/rtti.h" + +namespace cel { + +class TypedEnumValueFactory; +class TypeManager; + +// EnumType represents an enumeration type. An enumeration is a set of constants +// that can be looked up by name and/or number. +class EnumType : public Type { + public: + struct Constant; + + class ConstantId final { + public: + explicit ConstantId(absl::string_view name) + : data_(absl::in_place_type, name) {} + + explicit ConstantId(int64_t number) + : data_(absl::in_place_type, number) {} + + ConstantId() = delete; + + ConstantId(const ConstantId&) = default; + ConstantId& operator=(const ConstantId&) = default; + + private: + friend class EnumType; + friend class EnumValue; + + absl::variant data_; + }; + + Kind kind() const final { return Kind::kEnum; } + + // Find the constant definition for the given identifier. + absl::StatusOr FindConstant(ConstantId id) const; + + protected: + EnumType() = default; + + // Construct a new instance of EnumValue with a type of this. Called by + // EnumValue::New. + virtual absl::StatusOr> NewInstanceByName( + TypedEnumValueFactory& factory, absl::string_view name) const = 0; + + // Construct a new instance of EnumValue with a type of this. Called by + // EnumValue::New. + virtual absl::StatusOr> NewInstanceByNumber( + TypedEnumValueFactory& factory, int64_t number) const = 0; + + // Called by FindConstant. + virtual absl::StatusOr FindConstantByName( + absl::string_view name) const = 0; + + // Called by FindConstant. + virtual absl::StatusOr FindConstantByNumber( + int64_t number) const = 0; + + private: + friend internal::TypeInfo base_internal::GetEnumTypeTypeId( + const EnumType& enum_type); + struct NewInstanceVisitor; + struct FindConstantVisitor; + + friend struct NewInstanceVisitor; + friend struct FindConstantVisitor; + friend class EnumValue; + friend class TypeFactory; + friend class base_internal::TypeHandleBase; + + // Called by base_internal::TypeHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Type& type) { return type.kind() == Kind::kEnum; } + + EnumType(const EnumType&) = delete; + EnumType(EnumType&&) = delete; + + std::pair SizeAndAlignment() const override = 0; + + // Called by CEL_IMPLEMENT_ENUM_TYPE() and Is() to perform type checking. + virtual internal::TypeInfo TypeId() const = 0; +}; + +// CEL_DECLARE_ENUM_TYPE declares `enum_type` as an enumeration type. It must be +// part of the class definition of `enum_type`. +// +// class MyEnumType : public cel::EnumType { +// ... +// private: +// CEL_DECLARE_ENUM_TYPE(MyEnumType); +// }; +#define CEL_DECLARE_ENUM_TYPE(enum_type) \ + CEL_INTERNAL_DECLARE_TYPE(Enum, enum_type) + +// CEL_IMPLEMENT_ENUM_TYPE implements `enum_type` as an enumeration type. It +// must be called after the class definition of `enum_type`. +// +// class MyEnumType : public cel::EnumType { +// ... +// private: +// CEL_DECLARE_ENUM_TYPE(MyEnumType); +// }; +// +// CEL_IMPLEMENT_ENUM_TYPE(MyEnumType); +#define CEL_IMPLEMENT_ENUM_TYPE(enum_type) \ + CEL_INTERNAL_IMPLEMENT_TYPE(Enum, enum_type) + +struct EnumType::Constant final { + explicit Constant(absl::string_view name, int64_t number) + : name(name), number(number) {} + + // The unqualified enumeration value name. + absl::string_view name; + // The enumeration value number. + int64_t number; +}; + +CEL_INTERNAL_TYPE_DECL(EnumType); + +namespace base_internal { + +inline internal::TypeInfo GetEnumTypeTypeId(const EnumType& enum_type) { + return enum_type.TypeId(); +} + +} // namespace base_internal + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_TYPES_ENUM_TYPE_H_ diff --git a/base/types/error_type.cc b/base/types/error_type.cc new file mode 100644 index 000000000..f45c69d31 --- /dev/null +++ b/base/types/error_type.cc @@ -0,0 +1,28 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/types/error_type.h" + +#include "internal/no_destructor.h" + +namespace cel { + +CEL_INTERNAL_TYPE_IMPL(ErrorType); + +const ErrorType& ErrorType::Get() { + static const internal::NoDestructor instance; + return *instance; +} + +} // namespace cel diff --git a/base/types/error_type.h b/base/types/error_type.h new file mode 100644 index 000000000..96059aa0e --- /dev/null +++ b/base/types/error_type.h @@ -0,0 +1,57 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_TYPES_ERROR_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_BASE_TYPES_ERROR_TYPE_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "base/kind.h" +#include "base/type.h" + +namespace cel { + +class ErrorType final : public Type { + public: + Kind kind() const override { return Kind::kError; } + + absl::string_view name() const override { return "*error*"; } + + private: + friend class ErrorValue; + friend class TypeFactory; + template + friend class internal::NoDestructor; + friend class base_internal::TypeHandleBase; + + // Called by base_internal::TypeHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Type& type) { return type.kind() == Kind::kError; } + + ABSL_ATTRIBUTE_PURE_FUNCTION static const ErrorType& Get(); + + ErrorType() = default; + + ErrorType(const ErrorType&) = delete; + ErrorType(ErrorType&&) = delete; +}; + +CEL_INTERNAL_TYPE_DECL(ErrorType); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_TYPES_ERROR_TYPE_H_ diff --git a/base/types/int_type.cc b/base/types/int_type.cc new file mode 100644 index 000000000..e722f8f31 --- /dev/null +++ b/base/types/int_type.cc @@ -0,0 +1,28 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/types/int_type.h" + +#include "internal/no_destructor.h" + +namespace cel { + +CEL_INTERNAL_TYPE_IMPL(IntType); + +const IntType& IntType::Get() { + static const internal::NoDestructor instance; + return *instance; +} + +} // namespace cel diff --git a/base/types/int_type.h b/base/types/int_type.h new file mode 100644 index 000000000..579f41859 --- /dev/null +++ b/base/types/int_type.h @@ -0,0 +1,57 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_TYPES_INT_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_BASE_TYPES_INT_TYPE_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "base/kind.h" +#include "base/type.h" + +namespace cel { + +class IntType final : public Type { + public: + Kind kind() const override { return Kind::kInt; } + + absl::string_view name() const override { return "int"; } + + private: + friend class IntValue; + friend class TypeFactory; + template + friend class internal::NoDestructor; + friend class base_internal::TypeHandleBase; + + // Called by base_internal::TypeHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Type& type) { return type.kind() == Kind::kInt; } + + ABSL_ATTRIBUTE_PURE_FUNCTION static const IntType& Get(); + + IntType() = default; + + IntType(const IntType&) = delete; + IntType(IntType&&) = delete; +}; + +CEL_INTERNAL_TYPE_DECL(IntType); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_TYPES_INT_TYPE_H_ diff --git a/base/types/list_type.cc b/base/types/list_type.cc new file mode 100644 index 000000000..4f502fbf4 --- /dev/null +++ b/base/types/list_type.cc @@ -0,0 +1,40 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/types/list_type.h" + +#include "absl/strings/str_cat.h" + +namespace cel { + +CEL_INTERNAL_TYPE_IMPL(ListType); + +std::string ListType::DebugString() const { + return absl::StrCat(name(), "(", element()->DebugString(), ")"); +} + +bool ListType::Equals(const Type& other) const { + if (kind() != other.kind()) { + return false; + } + return element() == internal::down_cast(other).element(); +} + +void ListType::HashValue(absl::HashState state) const { + // We specifically hash the element first and then call the parent method to + // avoid hash suffix/prefix collisions. + Type::HashValue(absl::HashState::combine(std::move(state), element())); +} + +} // namespace cel diff --git a/base/types/list_type.h b/base/types/list_type.h new file mode 100644 index 000000000..d339c6229 --- /dev/null +++ b/base/types/list_type.h @@ -0,0 +1,73 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_TYPES_LIST_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_BASE_TYPES_LIST_TYPE_H_ + +#include +#include +#include + +#include "absl/hash/hash.h" +#include "absl/strings/string_view.h" +#include "base/kind.h" +#include "base/type.h" + +namespace cel { + +// ListType represents a list type. A list is a sequential container where each +// element is the same type. +class ListType : public Type { + // I would have liked to make this class final, but we cannot instantiate + // Persistent or Transient at this point. It must be + // done after the post include below. Maybe we should separate out the post + // includes on a per type basis so we can do that? + public: + Kind kind() const final { return Kind::kList; } + + absl::string_view name() const final { return "list"; } + + std::string DebugString() const final; + + // Returns the type of the elements in the list. + virtual Persistent element() const = 0; + + private: + friend class TypeFactory; + friend class base_internal::TypeHandleBase; + friend class base_internal::ListTypeImpl; + + // Called by base_internal::TypeHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Type& type) { return type.kind() == Kind::kList; } + + ListType() = default; + + ListType(const ListType&) = delete; + ListType(ListType&&) = delete; + + std::pair SizeAndAlignment() const override = 0; + + // Called by base_internal::TypeHandleBase. + bool Equals(const Type& other) const final; + + // Called by base_internal::TypeHandleBase. + void HashValue(absl::HashState state) const final; +}; + +CEL_INTERNAL_TYPE_DECL(ListType); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_TYPES_LIST_TYPE_H_ diff --git a/base/types/map_type.cc b/base/types/map_type.cc new file mode 100644 index 000000000..a55f275ca --- /dev/null +++ b/base/types/map_type.cc @@ -0,0 +1,42 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/types/map_type.h" + +#include "absl/strings/str_cat.h" + +namespace cel { + +CEL_INTERNAL_TYPE_IMPL(MapType); + +std::string MapType::DebugString() const { + return absl::StrCat(name(), "(", key()->DebugString(), ", ", + value()->DebugString(), ")"); +} + +bool MapType::Equals(const Type& other) const { + if (kind() != other.kind()) { + return false; + } + return key() == internal::down_cast(other).key() && + value() == internal::down_cast(other).value(); +} + +void MapType::HashValue(absl::HashState state) const { + // We specifically hash the element first and then call the parent method to + // avoid hash suffix/prefix collisions. + Type::HashValue(absl::HashState::combine(std::move(state), key(), value())); +} + +} // namespace cel diff --git a/base/types/map_type.h b/base/types/map_type.h new file mode 100644 index 000000000..9da1cb0f6 --- /dev/null +++ b/base/types/map_type.h @@ -0,0 +1,78 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_TYPES_MAP_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_BASE_TYPES_MAP_TYPE_H_ + +#include +#include +#include + +#include "absl/hash/hash.h" +#include "absl/strings/string_view.h" +#include "base/kind.h" +#include "base/type.h" + +namespace cel { + +class TypeFactory; + +// MapType represents a map type. A map is container of key and value pairs +// where each key appears at most once. +class MapType : public Type { + // I would have liked to make this class final, but we cannot instantiate + // Persistent or Transient at this point. It must be + // done after the post include below. Maybe we should separate out the post + // includes on a per type basis so we can do that? + public: + Kind kind() const final { return Kind::kMap; } + + absl::string_view name() const final { return "map"; } + + std::string DebugString() const final; + + // Returns the type of the keys in the map. + virtual Persistent key() const = 0; + + // Returns the type of the values in the map. + virtual Persistent value() const = 0; + + private: + friend class TypeFactory; + friend class base_internal::TypeHandleBase; + friend class base_internal::MapTypeImpl; + + // Called by base_internal::TypeHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Type& type) { return type.kind() == Kind::kMap; } + + MapType() = default; + + MapType(const MapType&) = delete; + MapType(MapType&&) = delete; + + std::pair SizeAndAlignment() const override = 0; + + // Called by base_internal::TypeHandleBase. + bool Equals(const Type& other) const final; + + // Called by base_internal::TypeHandleBase. + void HashValue(absl::HashState state) const final; +}; + +CEL_INTERNAL_TYPE_DECL(MapType); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_TYPES_MAP_TYPE_H_ diff --git a/base/types/null_type.cc b/base/types/null_type.cc new file mode 100644 index 000000000..b964cd2f1 --- /dev/null +++ b/base/types/null_type.cc @@ -0,0 +1,28 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/types/null_type.h" + +#include "internal/no_destructor.h" + +namespace cel { + +CEL_INTERNAL_TYPE_IMPL(NullType); + +const NullType& NullType::Get() { + static const internal::NoDestructor instance; + return *instance; +} + +} // namespace cel diff --git a/base/types/null_type.h b/base/types/null_type.h new file mode 100644 index 000000000..4544f73d5 --- /dev/null +++ b/base/types/null_type.h @@ -0,0 +1,58 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_TYPES_NULL_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_BASE_TYPES_NULL_TYPE_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "base/kind.h" +#include "base/type.h" + +namespace cel { + +class NullType final : public Type { + public: + Kind kind() const override { return Kind::kNullType; } + + absl::string_view name() const override { return "null_type"; } + + // Note GCC does not consider a friend member as a member of a friend. + ABSL_ATTRIBUTE_PURE_FUNCTION static const NullType& Get(); + + private: + friend class NullValue; + friend class TypeFactory; + template + friend class internal::NoDestructor; + friend class base_internal::TypeHandleBase; + + // Called by base_internal::TypeHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Type& type) { return type.kind() == Kind::kNullType; } + + NullType() = default; + + NullType(const NullType&) = delete; + NullType(NullType&&) = delete; +}; + +CEL_INTERNAL_TYPE_DECL(NullType); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_TYPES_NULL_TYPE_H_ diff --git a/base/types/string_type.cc b/base/types/string_type.cc new file mode 100644 index 000000000..57be9ac16 --- /dev/null +++ b/base/types/string_type.cc @@ -0,0 +1,28 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/types/string_type.h" + +#include "internal/no_destructor.h" + +namespace cel { + +CEL_INTERNAL_TYPE_IMPL(StringType); + +const StringType& StringType::Get() { + static const internal::NoDestructor instance; + return *instance; +} + +} // namespace cel diff --git a/base/types/string_type.h b/base/types/string_type.h new file mode 100644 index 000000000..eb75cdbf2 --- /dev/null +++ b/base/types/string_type.h @@ -0,0 +1,57 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_TYPES_STRING_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_BASE_TYPES_STRING_TYPE_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "base/kind.h" +#include "base/type.h" + +namespace cel { + +class StringType final : public Type { + public: + Kind kind() const override { return Kind::kString; } + + absl::string_view name() const override { return "string"; } + + private: + friend class StringValue; + friend class TypeFactory; + template + friend class internal::NoDestructor; + friend class base_internal::TypeHandleBase; + + // Called by base_internal::TypeHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Type& type) { return type.kind() == Kind::kString; } + + ABSL_ATTRIBUTE_PURE_FUNCTION static const StringType& Get(); + + StringType() = default; + + StringType(const StringType&) = delete; + StringType(StringType&&) = delete; +}; + +CEL_INTERNAL_TYPE_DECL(StringType); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_TYPES_STRING_TYPE_H_ diff --git a/base/types/struct_type.cc b/base/types/struct_type.cc new file mode 100644 index 000000000..0b5c20ece --- /dev/null +++ b/base/types/struct_type.cc @@ -0,0 +1,43 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/types/struct_type.h" + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/variant.h" + +namespace cel { + +CEL_INTERNAL_TYPE_IMPL(StructType); + +struct StructType::FindFieldVisitor final { + const StructType& struct_type; + TypeManager& type_manager; + + absl::StatusOr operator()(absl::string_view name) const { + return struct_type.FindFieldByName(type_manager, name); + } + + absl::StatusOr operator()(int64_t number) const { + return struct_type.FindFieldByNumber(type_manager, number); + } +}; + +absl::StatusOr StructType::FindField( + TypeManager& type_manager, FieldId id) const { + return absl::visit(FindFieldVisitor{*this, type_manager}, id.data_); +} + +} // namespace cel diff --git a/base/types/struct_type.h b/base/types/struct_type.h new file mode 100644 index 000000000..ef9acfb7f --- /dev/null +++ b/base/types/struct_type.h @@ -0,0 +1,152 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_TYPES_STRUCT_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_BASE_TYPES_STRUCT_TYPE_H_ + +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/variant.h" +#include "base/kind.h" +#include "base/type.h" +#include "internal/rtti.h" + +namespace cel { + +class TypedStructValueFactory; +class TypeManager; + +// StructType represents an struct type. An struct is a set of fields +// that can be looked up by name and/or number. +class StructType : public Type { + public: + struct Field; + + class FieldId final { + public: + explicit FieldId(absl::string_view name) + : data_(absl::in_place_type, name) {} + + explicit FieldId(int64_t number) + : data_(absl::in_place_type, number) {} + + FieldId() = delete; + + FieldId(const FieldId&) = default; + FieldId& operator=(const FieldId&) = default; + + private: + friend class StructType; + friend class StructValue; + + absl::variant data_; + }; + + Kind kind() const final { return Kind::kStruct; } + + // Find the field definition for the given identifier. + absl::StatusOr FindField(TypeManager& type_manager, FieldId id) const; + + protected: + StructType() = default; + + virtual absl::StatusOr> NewInstance( + TypedStructValueFactory& factory) const = 0; + + // Called by FindField. + virtual absl::StatusOr FindFieldByName( + TypeManager& type_manager, absl::string_view name) const = 0; + + // Called by FindField. + virtual absl::StatusOr FindFieldByNumber(TypeManager& type_manager, + int64_t number) const = 0; + + private: + friend internal::TypeInfo base_internal::GetStructTypeTypeId( + const StructType& struct_type); + struct FindFieldVisitor; + + friend struct FindFieldVisitor; + friend class TypeFactory; + friend class base_internal::TypeHandleBase; + friend class StructValue; + + // Called by base_internal::TypeHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Type& type) { return type.kind() == Kind::kStruct; } + + StructType(const StructType&) = delete; + StructType(StructType&&) = delete; + + std::pair SizeAndAlignment() const override = 0; + + // Called by CEL_IMPLEMENT_STRUCT_TYPE() and Is() to perform type checking. + virtual internal::TypeInfo TypeId() const = 0; +}; + +// CEL_DECLARE_STRUCT_TYPE declares `struct_type` as an struct type. It must be +// part of the class definition of `struct_type`. +// +// class MyStructType : public cel::StructType { +// ... +// private: +// CEL_DECLARE_STRUCT_TYPE(MyStructType); +// }; +#define CEL_DECLARE_STRUCT_TYPE(struct_type) \ + CEL_INTERNAL_DECLARE_TYPE(Struct, struct_type) + +// CEL_IMPLEMENT_ENUM_TYPE implements `struct_type` as an struct type. It +// must be called after the class definition of `struct_type`. +// +// class MyStructType : public cel::StructType { +// ... +// private: +// CEL_DECLARE_STRUCT_TYPE(MyStructType); +// }; +// +// CEL_IMPLEMENT_STRUCT_TYPE(MyStructType); +#define CEL_IMPLEMENT_STRUCT_TYPE(struct_type) \ + CEL_INTERNAL_IMPLEMENT_TYPE(Struct, struct_type) + +struct StructType::Field final { + explicit Field(absl::string_view name, int64_t number, + Persistent type) + : name(name), number(number), type(std::move(type)) {} + + // The field name. + absl::string_view name; + // The field number. + int64_t number; + // The field type; + Persistent type; +}; + +CEL_INTERNAL_TYPE_DECL(StructType); + +namespace base_internal { + +inline internal::TypeInfo GetStructTypeTypeId(const StructType& struct_type) { + return struct_type.TypeId(); +} + +} // namespace base_internal + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_TYPES_STRUCT_TYPE_H_ diff --git a/base/types/timestamp_type.cc b/base/types/timestamp_type.cc new file mode 100644 index 000000000..7b1239689 --- /dev/null +++ b/base/types/timestamp_type.cc @@ -0,0 +1,28 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/types/timestamp_type.h" + +#include "internal/no_destructor.h" + +namespace cel { + +CEL_INTERNAL_TYPE_IMPL(TimestampType); + +const TimestampType& TimestampType::Get() { + static const internal::NoDestructor instance; + return *instance; +} + +} // namespace cel diff --git a/base/types/timestamp_type.h b/base/types/timestamp_type.h new file mode 100644 index 000000000..20c7209b5 --- /dev/null +++ b/base/types/timestamp_type.h @@ -0,0 +1,59 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_TYPES_TIMESTAMP_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_BASE_TYPES_TIMESTAMP_TYPE_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "base/kind.h" +#include "base/type.h" + +namespace cel { + +class TimestampType final : public Type { + public: + Kind kind() const override { return Kind::kTimestamp; } + + absl::string_view name() const override { + return "google.protobuf.Timestamp"; + } + + private: + friend class TimestampValue; + friend class TypeFactory; + template + friend class internal::NoDestructor; + friend class base_internal::TypeHandleBase; + + // Called by base_internal::TypeHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Type& type) { return type.kind() == Kind::kTimestamp; } + + ABSL_ATTRIBUTE_PURE_FUNCTION static const TimestampType& Get(); + + TimestampType() = default; + + TimestampType(const TimestampType&) = delete; + TimestampType(TimestampType&&) = delete; +}; + +CEL_INTERNAL_TYPE_DECL(TimestampType); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_TYPES_TIMESTAMP_TYPE_H_ diff --git a/base/types/type_type.cc b/base/types/type_type.cc new file mode 100644 index 000000000..a6468b5a7 --- /dev/null +++ b/base/types/type_type.cc @@ -0,0 +1,28 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/types/type_type.h" + +#include "internal/no_destructor.h" + +namespace cel { + +CEL_INTERNAL_TYPE_IMPL(TypeType); + +const TypeType& TypeType::Get() { + static const internal::NoDestructor instance; + return *instance; +} + +} // namespace cel diff --git a/base/types/type_type.h b/base/types/type_type.h new file mode 100644 index 000000000..711db96a0 --- /dev/null +++ b/base/types/type_type.h @@ -0,0 +1,58 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_TYPES_TYPE_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_BASE_TYPES_TYPE_TYPE_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "base/kind.h" +#include "base/type.h" + +namespace cel { + +// TypeType represents the type of a type. +class TypeType final : public Type { + public: + Kind kind() const override { return Kind::kType; } + + absl::string_view name() const override { return "type"; } + + private: + friend class TypeValue; + friend class TypeFactory; + template + friend class internal::NoDestructor; + friend class base_internal::TypeHandleBase; + + // Called by base_internal::TypeHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Type& type) { return type.kind() == Kind::kType; } + + ABSL_ATTRIBUTE_PURE_FUNCTION static const TypeType& Get(); + + TypeType() = default; + + TypeType(const TypeType&) = delete; + TypeType(TypeType&&) = delete; +}; + +CEL_INTERNAL_TYPE_DECL(TypeType); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_TYPES_TYPE_TYPE_H_ diff --git a/base/types/uint_type.cc b/base/types/uint_type.cc new file mode 100644 index 000000000..1632b1d4c --- /dev/null +++ b/base/types/uint_type.cc @@ -0,0 +1,28 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/types/uint_type.h" + +#include "internal/no_destructor.h" + +namespace cel { + +CEL_INTERNAL_TYPE_IMPL(UintType); + +const UintType& UintType::Get() { + static const internal::NoDestructor instance; + return *instance; +} + +} // namespace cel diff --git a/base/types/uint_type.h b/base/types/uint_type.h new file mode 100644 index 000000000..1af31d0f1 --- /dev/null +++ b/base/types/uint_type.h @@ -0,0 +1,57 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_TYPES_UINT_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_BASE_TYPES_UINT_TYPE_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "base/kind.h" +#include "base/type.h" + +namespace cel { + +class UintType final : public Type { + public: + Kind kind() const override { return Kind::kUint; } + + absl::string_view name() const override { return "uint"; } + + private: + friend class UintValue; + friend class TypeFactory; + template + friend class internal::NoDestructor; + friend class base_internal::TypeHandleBase; + + // Called by base_internal::TypeHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Type& type) { return type.kind() == Kind::kUint; } + + ABSL_ATTRIBUTE_PURE_FUNCTION static const UintType& Get(); + + UintType() = default; + + UintType(const UintType&) = delete; + UintType(UintType&&) = delete; +}; + +CEL_INTERNAL_TYPE_DECL(UintType); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_TYPES_UINT_TYPE_H_ diff --git a/base/value.h b/base/value.h index 2f2f1be7a..6cc73e78a 100644 --- a/base/value.h +++ b/base/value.h @@ -35,6 +35,23 @@ #include "base/kind.h" #include "base/memory_manager.h" #include "base/type.h" +#include "base/types/any_type.h" +#include "base/types/bool_type.h" +#include "base/types/bytes_type.h" +#include "base/types/double_type.h" +#include "base/types/duration_type.h" +#include "base/types/dyn_type.h" +#include "base/types/enum_type.h" +#include "base/types/error_type.h" +#include "base/types/int_type.h" +#include "base/types/list_type.h" +#include "base/types/map_type.h" +#include "base/types/null_type.h" +#include "base/types/string_type.h" +#include "base/types/struct_type.h" +#include "base/types/timestamp_type.h" +#include "base/types/type_type.h" +#include "base/types/uint_type.h" #include "internal/casts.h" #include "internal/rtti.h" diff --git a/eval/public/structs/BUILD b/eval/public/structs/BUILD index 3c950ec94..4244683df 100644 --- a/eval/public/structs/BUILD +++ b/eval/public/structs/BUILD @@ -192,7 +192,7 @@ cc_library( hdrs = ["legacy_type_provider.h"], deps = [ ":legacy_type_adapter", - "//base:type", + "//base:type_provider", "@com_google_absl//absl/types:optional", ], ) From e64eb547714bf8c5b8e84684b0222ab04d307fdc Mon Sep 17 00:00:00 2001 From: jcking Date: Wed, 8 Jun 2022 19:08:23 +0000 Subject: [PATCH 002/303] Internal change PiperOrigin-RevId: 453736806 --- base/BUILD | 63 +- base/internal/value.post.h | 262 +------- base/internal/value.pre.h | 16 + base/type_test.cc | 2 + base/value.cc | 1018 +------------------------------- base/value.h | 832 -------------------------- base/value_factory.cc | 141 +++++ base/value_factory.h | 15 + base/value_test.cc | 6 +- base/values/BUILD | 259 ++++++++ base/values/bool_value.cc | 57 ++ base/values/bool_value.h | 71 +++ base/values/bytes_value.cc | 306 ++++++++++ base/values/bytes_value.h | 205 +++++++ base/values/double_value.cc | 83 +++ base/values/double_value.h | 75 +++ base/values/duration_value.cc | 58 ++ base/values/duration_value.h | 71 +++ base/values/enum_value.cc | 33 ++ base/values/enum_value.h | 116 ++++ base/values/error_value.cc | 86 +++ base/values/error_value.h | 67 +++ base/values/int_value.cc | 56 ++ base/values/int_value.h | 68 +++ base/values/list_value.cc | 21 + base/values/list_value.h | 118 ++++ base/values/map_value.cc | 21 + base/values/map_value.h | 117 ++++ base/values/null_value.cc | 59 ++ base/values/null_value.h | 69 +++ base/values/string_value.cc | 320 ++++++++++ base/values/string_value.h | 226 +++++++ base/values/struct_value.cc | 77 +++ base/values/struct_value.h | 145 +++++ base/values/timestamp_value.cc | 58 ++ base/values/timestamp_value.h | 74 +++ base/values/type_value.cc | 55 ++ base/values/type_value.h | 67 +++ base/values/uint_value.cc | 58 ++ base/values/uint_value.h | 66 +++ 40 files changed, 3380 insertions(+), 2137 deletions(-) create mode 100644 base/values/BUILD create mode 100644 base/values/bool_value.cc create mode 100644 base/values/bool_value.h create mode 100644 base/values/bytes_value.cc create mode 100644 base/values/bytes_value.h create mode 100644 base/values/double_value.cc create mode 100644 base/values/double_value.h create mode 100644 base/values/duration_value.cc create mode 100644 base/values/duration_value.h create mode 100644 base/values/enum_value.cc create mode 100644 base/values/enum_value.h create mode 100644 base/values/error_value.cc create mode 100644 base/values/error_value.h create mode 100644 base/values/int_value.cc create mode 100644 base/values/int_value.h create mode 100644 base/values/list_value.cc create mode 100644 base/values/list_value.h create mode 100644 base/values/map_value.cc create mode 100644 base/values/map_value.h create mode 100644 base/values/null_value.cc create mode 100644 base/values/null_value.h create mode 100644 base/values/string_value.cc create mode 100644 base/values/string_value.h create mode 100644 base/values/struct_value.cc create mode 100644 base/values/struct_value.h create mode 100644 base/values/timestamp_value.cc create mode 100644 base/values/timestamp_value.h create mode 100644 base/values/type_value.cc create mode 100644 base/values/type_value.h create mode 100644 base/values/uint_value.cc create mode 100644 base/values/uint_value.h diff --git a/base/BUILD b/base/BUILD index 052093a9a..3a1686404 100644 --- a/base/BUILD +++ b/base/BUILD @@ -203,6 +203,8 @@ cc_test( ":type_manager", ":value", "//base/internal:memory_manager_testing", + "//base/values:enum", + "//base/values:struct", "//internal:testing", "@com_google_absl//absl/hash", "@com_google_absl//absl/hash:hash_testing", @@ -214,55 +216,55 @@ cc_library( name = "value", srcs = [ "value.cc", - "value_factory.cc", ], hdrs = [ "value.h", - "value_factory.h", ], deps = [ ":handle", ":kind", ":memory_manager", ":type", - ":type_manager", "//base/internal:value", - "//base/types:any", - "//base/types:bool", - "//base/types:bytes", - "//base/types:double", - "//base/types:duration", - "//base/types:dyn", - "//base/types:enum", - "//base/types:error", - "//base/types:int", - "//base/types:list", - "//base/types:map", - "//base/types:null", - "//base/types:string", - "//base/types:struct", - "//base/types:timestamp", - "//base/types:type", - "//base/types:uint", - "//internal:casts", - "//internal:no_destructor", - "//internal:rtti", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:variant", + ], +) + +cc_library( + name = "value_factory", + srcs = ["value_factory.cc"], + hdrs = ["value_factory.h"], + deps = [ + ":handle", + ":memory_manager", + ":type_manager", + ":value", + "//base/values:bool", + "//base/values:bytes", + "//base/values:double", + "//base/values:duration", + "//base/values:enum", + "//base/values:error", + "//base/values:int", + "//base/values:list", + "//base/values:map", + "//base/values:null", + "//base/values:string", + "//base/values:struct", + "//base/values:timestamp", + "//base/values:type", + "//base/values:uint", "//internal:status_macros", - "//internal:strings", "//internal:time", "//internal:utf8", - "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:btree", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/hash", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:cord", "@com_google_absl//absl/time", - "@com_google_absl//absl/types:optional", - "@com_google_absl//absl/types:variant", ], ) @@ -278,6 +280,7 @@ cc_test( ":type_factory", ":type_manager", ":value", + ":value_factory", "//base/internal:memory_manager_testing", "//internal:strings", "//internal:testing", diff --git a/base/internal/value.post.h b/base/internal/value.post.h index 2e78ba07a..f6753c543 100644 --- a/base/internal/value.post.h +++ b/base/internal/value.post.h @@ -27,8 +27,6 @@ #include "absl/base/macros.h" #include "absl/base/optimization.h" #include "absl/hash/hash.h" -#include "absl/numeric/bits.h" -#include "absl/strings/cord.h" #include "absl/strings/string_view.h" #include "base/handle.h" #include "internal/casts.h" @@ -37,224 +35,6 @@ namespace cel { namespace base_internal { -inline internal::TypeInfo GetEnumValueTypeId(const EnumValue& enum_value) { - return enum_value.TypeId(); -} - -inline internal::TypeInfo GetStructValueTypeId( - const StructValue& struct_value) { - return struct_value.TypeId(); -} - -inline internal::TypeInfo GetListValueTypeId(const ListValue& list_value) { - return list_value.TypeId(); -} - -inline internal::TypeInfo GetMapValueTypeId(const MapValue& map_value) { - return map_value.TypeId(); -} - -// Implementation of BytesValue that is stored inlined within a handle. Since -// absl::Cord is reference counted itself, this is more efficient than storing -// this on the heap. -class InlinedCordBytesValue final : public BytesValue, public ResourceInlined { - private: - template - friend class ValueHandle; - - explicit InlinedCordBytesValue(absl::Cord value) : value_(std::move(value)) {} - - InlinedCordBytesValue() = delete; - - InlinedCordBytesValue(const InlinedCordBytesValue&) = default; - InlinedCordBytesValue(InlinedCordBytesValue&&) = default; - - // See comments for respective member functions on `ByteValue` and `Value`. - void CopyTo(Value& address) const override; - void MoveTo(Value& address) override; - absl::Cord ToCord(bool reference_counted) const override; - Rep rep() const override; - - absl::Cord value_; -}; - -// Implementation of BytesValue that is stored inlined within a handle. This -// class is inheritently unsafe and care should be taken when using it. -// Typically this should only be used for empty strings or data that is static -// and lives for the duration of a program. -class InlinedStringViewBytesValue final : public BytesValue, - public ResourceInlined { - private: - template - friend class ValueHandle; - - explicit InlinedStringViewBytesValue(absl::string_view value) - : value_(value) {} - - InlinedStringViewBytesValue() = delete; - - InlinedStringViewBytesValue(const InlinedStringViewBytesValue&) = default; - InlinedStringViewBytesValue(InlinedStringViewBytesValue&&) = default; - - // See comments for respective member functions on `ByteValue` and `Value`. - void CopyTo(Value& address) const override; - void MoveTo(Value& address) override; - absl::Cord ToCord(bool reference_counted) const override; - Rep rep() const override; - - absl::string_view value_; -}; - -// Implementation of BytesValue that uses std::string and is allocated on the -// heap, potentially reference counted. -class StringBytesValue final : public BytesValue { - private: - friend class cel::MemoryManager; - - explicit StringBytesValue(std::string value) : value_(std::move(value)) {} - - StringBytesValue() = delete; - StringBytesValue(const StringBytesValue&) = delete; - StringBytesValue(StringBytesValue&&) = delete; - - // See comments for respective member functions on `ByteValue` and `Value`. - std::pair SizeAndAlignment() const override; - absl::Cord ToCord(bool reference_counted) const override; - Rep rep() const override; - - std::string value_; -}; - -// Implementation of BytesValue that wraps a contiguous array of bytes and calls -// the releaser when it is no longer needed. It is stored on the heap and -// potentially reference counted. -class ExternalDataBytesValue final : public BytesValue { - private: - friend class cel::MemoryManager; - - explicit ExternalDataBytesValue(ExternalData value) - : value_(std::move(value)) {} - - ExternalDataBytesValue() = delete; - ExternalDataBytesValue(const ExternalDataBytesValue&) = delete; - ExternalDataBytesValue(ExternalDataBytesValue&&) = delete; - - // See comments for respective member functions on `ByteValue` and `Value`. - std::pair SizeAndAlignment() const override; - absl::Cord ToCord(bool reference_counted) const override; - Rep rep() const override; - - ExternalData value_; -}; - -// Implementation of StringValue that is stored inlined within a handle. Since -// absl::Cord is reference counted itself, this is more efficient then storing -// this on the heap. -class InlinedCordStringValue final : public StringValue, - public ResourceInlined { - private: - template - friend class ValueHandle; - - explicit InlinedCordStringValue(absl::Cord value) - : InlinedCordStringValue(0, std::move(value)) {} - - InlinedCordStringValue(size_t size, absl::Cord value) - : StringValue(size), value_(std::move(value)) {} - - InlinedCordStringValue() = delete; - - InlinedCordStringValue(const InlinedCordStringValue&) = default; - InlinedCordStringValue(InlinedCordStringValue&&) = default; - - // See comments for respective member functions on `StringValue` and `Value`. - void CopyTo(Value& address) const override; - void MoveTo(Value& address) override; - absl::Cord ToCord(bool reference_counted) const override; - Rep rep() const override; - - absl::Cord value_; -}; - -// Implementation of StringValue that is stored inlined within a handle. This -// class is inheritently unsafe and care should be taken when using it. -// Typically this should only be used for empty strings or data that is static -// and lives for the duration of a program. -class InlinedStringViewStringValue final : public StringValue, - public ResourceInlined { - private: - template - friend class ValueHandle; - - explicit InlinedStringViewStringValue(absl::string_view value) - : InlinedStringViewStringValue(0, value) {} - - InlinedStringViewStringValue(size_t size, absl::string_view value) - : StringValue(size), value_(value) {} - - InlinedStringViewStringValue() = delete; - - InlinedStringViewStringValue(const InlinedStringViewStringValue&) = default; - InlinedStringViewStringValue(InlinedStringViewStringValue&&) = default; - - // See comments for respective member functions on `StringValue` and `Value`. - void CopyTo(Value& address) const override; - void MoveTo(Value& address) override; - absl::Cord ToCord(bool reference_counted) const override; - Rep rep() const override; - - absl::string_view value_; -}; - -// Implementation of StringValue that uses std::string and is allocated on the -// heap, potentially reference counted. -class StringStringValue final : public StringValue { - private: - friend class cel::MemoryManager; - - explicit StringStringValue(std::string value) - : StringStringValue(0, std::move(value)) {} - - StringStringValue(size_t size, std::string value) - : StringValue(size), value_(std::move(value)) {} - - StringStringValue() = delete; - StringStringValue(const StringStringValue&) = delete; - StringStringValue(StringStringValue&&) = delete; - - // See comments for respective member functions on `StringValue` and `Value`. - std::pair SizeAndAlignment() const override; - absl::Cord ToCord(bool reference_counted) const override; - Rep rep() const override; - - std::string value_; -}; - -// Implementation of StringValue that wraps a contiguous array of bytes and -// calls the releaser when it is no longer needed. It is stored on the heap and -// potentially reference counted. -class ExternalDataStringValue final : public StringValue { - private: - friend class cel::MemoryManager; - - explicit ExternalDataStringValue(ExternalData value) - : ExternalDataStringValue(0, std::move(value)) {} - - ExternalDataStringValue(size_t size, ExternalData value) - : StringValue(size), value_(std::move(value)) {} - - ExternalDataStringValue() = delete; - ExternalDataStringValue(const ExternalDataStringValue&) = delete; - ExternalDataStringValue(ExternalDataStringValue&&) = delete; - - // See comments for respective member functions on `StringValue` and `Value`. - std::pair SizeAndAlignment() const override; - absl::Cord ToCord(bool reference_counted) const override; - Rep rep() const override; - - ExternalData value_; -}; - struct ABSL_ATTRIBUTE_UNUSED CheckVptrOffsetBase { virtual ~CheckVptrOffsetBase() = default; @@ -282,12 +62,7 @@ union ValueHandleData final { // bits to differentiate between an inlined value (both 0), a heap allocated // reference counted value, or a arena allocated value. void* vptr; - std::aligned_union_t - padding; + alignas(std::max_align_t) char padding[32]; }; // Base implementation of persistent and transient handles for values. This @@ -333,12 +108,12 @@ class ValueHandleBase { // Called by `Transient` and `Persistent` to implement the same operator. friend bool operator==(const ValueHandleBase& lhs, const ValueHandleBase& rhs) { - const Value& lhs_value = ABSL_PREDICT_TRUE(static_cast(lhs)) - ? lhs.get() - : static_cast(NullValue::Get()); - const Value& rhs_value = ABSL_PREDICT_TRUE(static_cast(rhs)) - ? rhs.get() - : static_cast(NullValue::Get()); + if (static_cast(lhs) != static_cast(rhs) || + !static_cast(lhs)) { + return false; + } + const Value& lhs_value = lhs.get(); + const Value& rhs_value = rhs.get(); return lhs_value.Equals(rhs_value); } @@ -359,8 +134,6 @@ class ValueHandleBase { friend H AbslHashValue(H state, const ValueHandleBase& handle) { if (ABSL_PREDICT_TRUE(static_cast(handle))) { handle.get().HashValue(absl::HashState::Create(&state)); - } else { - NullValue::Get().HashValue(absl::HashState::Create(&state)); } return state; } @@ -536,27 +309,6 @@ struct HandleTraits; \ - extern template class Persistent -CEL_INTERNAL_VALUE_DECL(Value); -CEL_INTERNAL_VALUE_DECL(NullValue); -CEL_INTERNAL_VALUE_DECL(ErrorValue); -CEL_INTERNAL_VALUE_DECL(BoolValue); -CEL_INTERNAL_VALUE_DECL(IntValue); -CEL_INTERNAL_VALUE_DECL(UintValue); -CEL_INTERNAL_VALUE_DECL(DoubleValue); -CEL_INTERNAL_VALUE_DECL(BytesValue); -CEL_INTERNAL_VALUE_DECL(StringValue); -CEL_INTERNAL_VALUE_DECL(DurationValue); -CEL_INTERNAL_VALUE_DECL(TimestampValue); -CEL_INTERNAL_VALUE_DECL(EnumValue); -CEL_INTERNAL_VALUE_DECL(StructValue); -CEL_INTERNAL_VALUE_DECL(ListValue); -CEL_INTERNAL_VALUE_DECL(MapValue); -CEL_INTERNAL_VALUE_DECL(TypeValue); -#undef CEL_INTERNAL_VALUE_DECL - } // namespace cel #endif // THIRD_PARTY_CEL_CPP_BASE_INTERNAL_VALUE_POST_H_ diff --git a/base/internal/value.pre.h b/base/internal/value.pre.h index 4a2a41820..3cda23dc8 100644 --- a/base/internal/value.pre.h +++ b/base/internal/value.pre.h @@ -185,6 +185,22 @@ using BytesValueRep = } // namespace cel +#define CEL_INTERNAL_VALUE_DECL(name) \ + extern template class Persistent; \ + extern template class Persistent + +#define CEL_INTERNAL_VALUE_IMPL(name) \ + template class Persistent; \ + template class Persistent + +// Both are equivalent to std::construct_at implementation from C++20. +#define CEL_INTERNAL_VALUE_COPY_TO(type, src, dest) \ + ::new (const_cast( \ + static_cast(std::addressof(dest)))) type(src) +#define CEL_INTERNAL_VALUE_MOVE_TO(type, src, dest) \ + ::new (const_cast(static_cast( \ + std::addressof(dest)))) type(std::move(src)) + #define CEL_INTERNAL_DECLARE_VALUE(base, derived) \ private: \ friend class ::cel::base_internal::ValueHandleBase; \ diff --git a/base/type_test.cc b/base/type_test.cc index 50e0d122a..e057bd299 100644 --- a/base/type_test.cc +++ b/base/type_test.cc @@ -26,6 +26,8 @@ #include "base/type_factory.h" #include "base/type_manager.h" #include "base/value.h" +#include "base/values/enum_value.h" +#include "base/values/struct_value.h" #include "internal/testing.h" namespace cel { diff --git a/base/value.cc b/base/value.cc index 42b728371..add9bdaa7 100644 --- a/base/value.cc +++ b/base/value.cc @@ -14,73 +14,11 @@ #include "base/value.h" -#include -#include -#include -#include -#include -#include -#include +#include #include -#include "absl/base/attributes.h" -#include "absl/base/call_once.h" -#include "absl/base/macros.h" -#include "absl/base/optimization.h" -#include "absl/container/btree_set.h" -#include "absl/container/inlined_vector.h" -#include "absl/hash/hash.h" -#include "absl/status/status.h" -#include "absl/strings/cord.h" -#include "absl/strings/match.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "absl/time/time.h" -#include "base/value_factory.h" -#include "internal/casts.h" -#include "internal/no_destructor.h" -#include "internal/status_macros.h" -#include "internal/strings.h" -#include "internal/time.h" -#include "internal/utf8.h" - namespace cel { -#define CEL_INTERNAL_VALUE_IMPL(name) \ - template class Persistent; \ - template class Persistent -CEL_INTERNAL_VALUE_IMPL(Value); -CEL_INTERNAL_VALUE_IMPL(NullValue); -CEL_INTERNAL_VALUE_IMPL(ErrorValue); -CEL_INTERNAL_VALUE_IMPL(BoolValue); -CEL_INTERNAL_VALUE_IMPL(IntValue); -CEL_INTERNAL_VALUE_IMPL(UintValue); -CEL_INTERNAL_VALUE_IMPL(DoubleValue); -CEL_INTERNAL_VALUE_IMPL(BytesValue); -CEL_INTERNAL_VALUE_IMPL(StringValue); -CEL_INTERNAL_VALUE_IMPL(DurationValue); -CEL_INTERNAL_VALUE_IMPL(TimestampValue); -CEL_INTERNAL_VALUE_IMPL(EnumValue); -CEL_INTERNAL_VALUE_IMPL(StructValue); -CEL_INTERNAL_VALUE_IMPL(ListValue); -CEL_INTERNAL_VALUE_IMPL(MapValue); -CEL_INTERNAL_VALUE_IMPL(TypeValue); -#undef CEL_INTERNAL_VALUE_IMPL - -namespace { - -using base_internal::PersistentHandleFactory; - -// Both are equivalent to std::construct_at implementation from C++20. -#define CEL_COPY_TO_IMPL(type, src, dest) \ - ::new (const_cast( \ - static_cast(std::addressof(dest)))) type(src) -#define CEL_MOVE_TO_IMPL(type, src, dest) \ - ::new (const_cast(static_cast( \ - std::addressof(dest)))) type(std::move(src)) - -} // namespace - std::pair Value::SizeAndAlignment() const { // Currently most implementations of Value are not reference counted, so those // that are override this and those that do not inherit this. Using 0 here @@ -92,958 +30,4 @@ void Value::CopyTo(Value& address) const {} void Value::MoveTo(Value& address) {} -Persistent NullValue::Get(ValueFactory& value_factory) { - return value_factory.GetNullValue(); -} - -Persistent NullValue::type() const { - return PersistentHandleFactory::MakeUnmanaged( - NullType::Get()); -} - -std::string NullValue::DebugString() const { return "null"; } - -const NullValue& NullValue::Get() { - static const internal::NoDestructor instance; - return *instance; -} - -void NullValue::CopyTo(Value& address) const { - CEL_COPY_TO_IMPL(NullValue, *this, address); -} - -void NullValue::MoveTo(Value& address) { - CEL_MOVE_TO_IMPL(NullValue, *this, address); -} - -bool NullValue::Equals(const Value& other) const { - return kind() == other.kind(); -} - -void NullValue::HashValue(absl::HashState state) const { - absl::HashState::combine(std::move(state), type(), 0); -} - -namespace { - -struct StatusPayload final { - std::string key; - absl::Cord value; -}; - -void StatusHashValue(absl::HashState state, const absl::Status& status) { - // absl::Status::operator== compares `raw_code()`, `message()` and the - // payloads. - state = absl::HashState::combine(std::move(state), status.raw_code(), - status.message()); - // In order to determistically hash, we need to put the payloads in sorted - // order. There is no guarantee from `absl::Status` on the order of the - // payloads returned from `absl::Status::ForEachPayload`. - // - // This should be the same inline size as - // `absl::status_internal::StatusPayloads`. - absl::InlinedVector payloads; - status.ForEachPayload([&](absl::string_view key, const absl::Cord& value) { - payloads.push_back(StatusPayload{std::string(key), value}); - }); - std::stable_sort( - payloads.begin(), payloads.end(), - [](const StatusPayload& lhs, const StatusPayload& rhs) -> bool { - return lhs.key < rhs.key; - }); - for (const auto& payload : payloads) { - state = - absl::HashState::combine(std::move(state), payload.key, payload.value); - } -} - -} // namespace - -Persistent ErrorValue::type() const { - return PersistentHandleFactory::MakeUnmanaged( - ErrorType::Get()); -} - -std::string ErrorValue::DebugString() const { return value().ToString(); } - -void ErrorValue::CopyTo(Value& address) const { - CEL_COPY_TO_IMPL(ErrorValue, *this, address); -} - -void ErrorValue::MoveTo(Value& address) { - CEL_MOVE_TO_IMPL(ErrorValue, *this, address); -} - -bool ErrorValue::Equals(const Value& other) const { - return kind() == other.kind() && - value() == internal::down_cast(other).value(); -} - -void ErrorValue::HashValue(absl::HashState state) const { - StatusHashValue(absl::HashState::combine(std::move(state), type()), value()); -} - -Persistent BoolValue::False(ValueFactory& value_factory) { - return value_factory.CreateBoolValue(false); -} - -Persistent BoolValue::True(ValueFactory& value_factory) { - return value_factory.CreateBoolValue(true); -} - -Persistent BoolValue::type() const { - return PersistentHandleFactory::MakeUnmanaged( - BoolType::Get()); -} - -std::string BoolValue::DebugString() const { - return value() ? "true" : "false"; -} - -void BoolValue::CopyTo(Value& address) const { - CEL_COPY_TO_IMPL(BoolValue, *this, address); -} - -void BoolValue::MoveTo(Value& address) { - CEL_MOVE_TO_IMPL(BoolValue, *this, address); -} - -bool BoolValue::Equals(const Value& other) const { - return kind() == other.kind() && - value() == internal::down_cast(other).value(); -} - -void BoolValue::HashValue(absl::HashState state) const { - absl::HashState::combine(std::move(state), type(), value()); -} - -Persistent IntValue::type() const { - return PersistentHandleFactory::MakeUnmanaged( - IntType::Get()); -} - -std::string IntValue::DebugString() const { return absl::StrCat(value()); } - -void IntValue::CopyTo(Value& address) const { - CEL_COPY_TO_IMPL(IntValue, *this, address); -} - -void IntValue::MoveTo(Value& address) { - CEL_MOVE_TO_IMPL(IntValue, *this, address); -} - -bool IntValue::Equals(const Value& other) const { - return kind() == other.kind() && - value() == internal::down_cast(other).value(); -} - -void IntValue::HashValue(absl::HashState state) const { - absl::HashState::combine(std::move(state), type(), value()); -} - -Persistent UintValue::type() const { - return PersistentHandleFactory::MakeUnmanaged( - UintType::Get()); -} - -std::string UintValue::DebugString() const { - return absl::StrCat(value(), "u"); -} - -void UintValue::CopyTo(Value& address) const { - CEL_COPY_TO_IMPL(UintValue, *this, address); -} - -void UintValue::MoveTo(Value& address) { - CEL_MOVE_TO_IMPL(UintValue, *this, address); -} - -bool UintValue::Equals(const Value& other) const { - return kind() == other.kind() && - value() == internal::down_cast(other).value(); -} - -void UintValue::HashValue(absl::HashState state) const { - absl::HashState::combine(std::move(state), type(), value()); -} - -Persistent DoubleValue::NaN(ValueFactory& value_factory) { - return value_factory.CreateDoubleValue( - std::numeric_limits::quiet_NaN()); -} - -Persistent DoubleValue::PositiveInfinity( - ValueFactory& value_factory) { - return value_factory.CreateDoubleValue( - std::numeric_limits::infinity()); -} - -Persistent DoubleValue::NegativeInfinity( - ValueFactory& value_factory) { - return value_factory.CreateDoubleValue( - -std::numeric_limits::infinity()); -} - -Persistent DoubleValue::type() const { - return PersistentHandleFactory::MakeUnmanaged( - DoubleType::Get()); -} - -std::string DoubleValue::DebugString() const { - if (std::isfinite(value())) { - if (std::floor(value()) != value()) { - // The double is not representable as a whole number, so use - // absl::StrCat which will add decimal places. - return absl::StrCat(value()); - } - // absl::StrCat historically would represent 0.0 as 0, and we want the - // decimal places so ZetaSQL correctly assumes the type as double - // instead of int64_t. - std::string stringified = absl::StrCat(value()); - if (!absl::StrContains(stringified, '.')) { - absl::StrAppend(&stringified, ".0"); - } else { - // absl::StrCat has a decimal now? Use it directly. - } - return stringified; - } - if (std::isnan(value())) { - return "nan"; - } - if (std::signbit(value())) { - return "-infinity"; - } - return "+infinity"; -} - -void DoubleValue::CopyTo(Value& address) const { - CEL_COPY_TO_IMPL(DoubleValue, *this, address); -} - -void DoubleValue::MoveTo(Value& address) { - CEL_MOVE_TO_IMPL(DoubleValue, *this, address); -} - -bool DoubleValue::Equals(const Value& other) const { - return kind() == other.kind() && - value() == internal::down_cast(other).value(); -} - -void DoubleValue::HashValue(absl::HashState state) const { - absl::HashState::combine(std::move(state), type(), value()); -} - -Persistent DurationValue::Zero( - ValueFactory& value_factory) { - // Should never fail, tests assert this. - return value_factory.CreateDurationValue(absl::ZeroDuration()).value(); -} - -Persistent DurationValue::type() const { - return PersistentHandleFactory::MakeUnmanaged( - DurationType::Get()); -} - -std::string DurationValue::DebugString() const { - return internal::FormatDuration(value()).value(); -} - -void DurationValue::CopyTo(Value& address) const { - CEL_COPY_TO_IMPL(DurationValue, *this, address); -} - -void DurationValue::MoveTo(Value& address) { - CEL_MOVE_TO_IMPL(DurationValue, *this, address); -} - -bool DurationValue::Equals(const Value& other) const { - return kind() == other.kind() && - value() == internal::down_cast(other).value(); -} - -void DurationValue::HashValue(absl::HashState state) const { - absl::HashState::combine(std::move(state), type(), value()); -} - -Persistent TimestampValue::UnixEpoch( - ValueFactory& value_factory) { - // Should never fail, tests assert this. - return value_factory.CreateTimestampValue(absl::UnixEpoch()).value(); -} - -Persistent TimestampValue::type() const { - return PersistentHandleFactory::MakeUnmanaged< - const TimestampType>(TimestampType::Get()); -} - -std::string TimestampValue::DebugString() const { - return internal::FormatTimestamp(value()).value(); -} - -void TimestampValue::CopyTo(Value& address) const { - CEL_COPY_TO_IMPL(TimestampValue, *this, address); -} - -void TimestampValue::MoveTo(Value& address) { - CEL_MOVE_TO_IMPL(TimestampValue, *this, address); -} - -bool TimestampValue::Equals(const Value& other) const { - return kind() == other.kind() && - value() == internal::down_cast(other).value(); -} - -void TimestampValue::HashValue(absl::HashState state) const { - absl::HashState::combine(std::move(state), type(), value()); -} - -namespace { - -struct BytesValueDebugStringVisitor final { - std::string operator()(absl::string_view value) const { - return internal::FormatBytesLiteral(value); - } - - std::string operator()(const absl::Cord& value) const { - return internal::FormatBytesLiteral(static_cast(value)); - } -}; - -struct StringValueDebugStringVisitor final { - std::string operator()(absl::string_view value) const { - return internal::FormatStringLiteral(value); - } - - std::string operator()(const absl::Cord& value) const { - return internal::FormatStringLiteral(static_cast(value)); - } -}; - -struct ToStringVisitor final { - std::string operator()(absl::string_view value) const { - return std::string(value); - } - - std::string operator()(const absl::Cord& value) const { - return static_cast(value); - } -}; - -struct BytesValueSizeVisitor final { - size_t operator()(absl::string_view value) const { return value.size(); } - - size_t operator()(const absl::Cord& value) const { return value.size(); } -}; - -struct StringValueSizeVisitor final { - size_t operator()(absl::string_view value) const { - return internal::Utf8CodePointCount(value); - } - - size_t operator()(const absl::Cord& value) const { - return internal::Utf8CodePointCount(value); - } -}; - -struct EmptyVisitor final { - bool operator()(absl::string_view value) const { return value.empty(); } - - bool operator()(const absl::Cord& value) const { return value.empty(); } -}; - -bool EqualsImpl(absl::string_view lhs, absl::string_view rhs) { - return lhs == rhs; -} - -bool EqualsImpl(absl::string_view lhs, const absl::Cord& rhs) { - return lhs == rhs; -} - -bool EqualsImpl(const absl::Cord& lhs, absl::string_view rhs) { - return lhs == rhs; -} - -bool EqualsImpl(const absl::Cord& lhs, const absl::Cord& rhs) { - return lhs == rhs; -} - -int CompareImpl(absl::string_view lhs, absl::string_view rhs) { - return lhs.compare(rhs); -} - -int CompareImpl(absl::string_view lhs, const absl::Cord& rhs) { - return -rhs.Compare(lhs); -} - -int CompareImpl(const absl::Cord& lhs, absl::string_view rhs) { - return lhs.Compare(rhs); -} - -int CompareImpl(const absl::Cord& lhs, const absl::Cord& rhs) { - return lhs.Compare(rhs); -} - -template -class EqualsVisitor final { - public: - explicit EqualsVisitor(const T& ref) : ref_(ref) {} - - bool operator()(absl::string_view value) const { - return EqualsImpl(value, ref_); - } - - bool operator()(const absl::Cord& value) const { - return EqualsImpl(value, ref_); - } - - private: - const T& ref_; -}; - -template <> -class EqualsVisitor final { - public: - explicit EqualsVisitor(const BytesValue& ref) : ref_(ref) {} - - bool operator()(absl::string_view value) const { return ref_.Equals(value); } - - bool operator()(const absl::Cord& value) const { return ref_.Equals(value); } - - private: - const BytesValue& ref_; -}; - -template <> -class EqualsVisitor final { - public: - explicit EqualsVisitor(const StringValue& ref) : ref_(ref) {} - - bool operator()(absl::string_view value) const { return ref_.Equals(value); } - - bool operator()(const absl::Cord& value) const { return ref_.Equals(value); } - - private: - const StringValue& ref_; -}; - -template -class CompareVisitor final { - public: - explicit CompareVisitor(const T& ref) : ref_(ref) {} - - int operator()(absl::string_view value) const { - return CompareImpl(value, ref_); - } - - int operator()(const absl::Cord& value) const { - return CompareImpl(value, ref_); - } - - private: - const T& ref_; -}; - -template <> -class CompareVisitor final { - public: - explicit CompareVisitor(const BytesValue& ref) : ref_(ref) {} - - int operator()(const absl::Cord& value) const { return ref_.Compare(value); } - - int operator()(absl::string_view value) const { return ref_.Compare(value); } - - private: - const BytesValue& ref_; -}; - -template <> -class CompareVisitor final { - public: - explicit CompareVisitor(const StringValue& ref) : ref_(ref) {} - - int operator()(const absl::Cord& value) const { return ref_.Compare(value); } - - int operator()(absl::string_view value) const { return ref_.Compare(value); } - - private: - const StringValue& ref_; -}; - -class HashValueVisitor final { - public: - explicit HashValueVisitor(absl::HashState state) : state_(std::move(state)) {} - - void operator()(absl::string_view value) { - absl::HashState::combine(std::move(state_), value); - } - - void operator()(const absl::Cord& value) { - absl::HashState::combine(std::move(state_), value); - } - - private: - absl::HashState state_; -}; - -template -bool CanPerformZeroCopy(MemoryManager& memory_manager, - const Persistent& handle) { - return base_internal::IsManagedHandle(handle) && - std::addressof(memory_manager) == - std::addressof(base_internal::GetMemoryManager(handle)); -} - -} // namespace - -Persistent BytesValue::Empty(ValueFactory& value_factory) { - return value_factory.GetBytesValue(); -} - -absl::StatusOr> BytesValue::Concat( - ValueFactory& value_factory, const Persistent& lhs, - const Persistent& rhs) { - absl::Cord cord; - // We can only use the potential zero-copy path if the memory managers are - // the same. Otherwise we need to escape the original memory manager scope. - cord.Append( - lhs->ToCord(CanPerformZeroCopy(value_factory.memory_manager(), lhs))); - // We can only use the potential zero-copy path if the memory managers are - // the same. Otherwise we need to escape the original memory manager scope. - cord.Append( - rhs->ToCord(CanPerformZeroCopy(value_factory.memory_manager(), rhs))); - return value_factory.CreateBytesValue(std::move(cord)); -} - -Persistent BytesValue::type() const { - return PersistentHandleFactory::MakeUnmanaged( - BytesType::Get()); -} - -size_t BytesValue::size() const { - return absl::visit(BytesValueSizeVisitor{}, rep()); -} - -bool BytesValue::empty() const { return absl::visit(EmptyVisitor{}, rep()); } - -bool BytesValue::Equals(absl::string_view bytes) const { - return absl::visit(EqualsVisitor(bytes), rep()); -} - -bool BytesValue::Equals(const absl::Cord& bytes) const { - return absl::visit(EqualsVisitor(bytes), rep()); -} - -bool BytesValue::Equals(const Persistent& bytes) const { - return absl::visit(EqualsVisitor(*this), bytes->rep()); -} - -int BytesValue::Compare(absl::string_view bytes) const { - return absl::visit(CompareVisitor(bytes), rep()); -} - -int BytesValue::Compare(const absl::Cord& bytes) const { - return absl::visit(CompareVisitor(bytes), rep()); -} - -int BytesValue::Compare(const Persistent& bytes) const { - return absl::visit(CompareVisitor(*this), bytes->rep()); -} - -std::string BytesValue::ToString() const { - return absl::visit(ToStringVisitor{}, rep()); -} - -std::string BytesValue::DebugString() const { - return absl::visit(BytesValueDebugStringVisitor{}, rep()); -} - -bool BytesValue::Equals(const Value& other) const { - return kind() == other.kind() && - absl::visit(EqualsVisitor(*this), - internal::down_cast(other).rep()); -} - -void BytesValue::HashValue(absl::HashState state) const { - absl::visit( - HashValueVisitor(absl::HashState::combine(std::move(state), type())), - rep()); -} - -Persistent StringValue::Empty(ValueFactory& value_factory) { - return value_factory.GetStringValue(); -} - -absl::StatusOr> StringValue::Concat( - ValueFactory& value_factory, const Persistent& lhs, - const Persistent& rhs) { - absl::Cord cord; - // We can only use the potential zero-copy path if the memory managers are - // the same. Otherwise we need to escape the original memory manager scope. - cord.Append( - lhs->ToCord(CanPerformZeroCopy(value_factory.memory_manager(), lhs))); - // We can only use the potential zero-copy path if the memory managers are - // the same. Otherwise we need to escape the original memory manager scope. - cord.Append( - rhs->ToCord(CanPerformZeroCopy(value_factory.memory_manager(), rhs))); - size_t size = 0; - size_t lhs_size = lhs->size_.load(std::memory_order_relaxed); - if (lhs_size != 0 && !lhs->empty()) { - size_t rhs_size = rhs->size_.load(std::memory_order_relaxed); - if (rhs_size != 0 && !rhs->empty()) { - size = lhs_size + rhs_size; - } - } - return value_factory.CreateStringValue(std::move(cord), size); -} - -Persistent StringValue::type() const { - return PersistentHandleFactory::MakeUnmanaged( - StringType::Get()); -} - -size_t StringValue::size() const { - // We lazily calculate the code point count in some circumstances. If the code - // point count is 0 and the underlying rep is not empty we need to actually - // calculate the size. It is okay if this is done by multiple threads - // simultaneously, it is a benign race. - size_t size = size_.load(std::memory_order_relaxed); - if (size == 0 && !empty()) { - size = absl::visit(StringValueSizeVisitor{}, rep()); - size_.store(size, std::memory_order_relaxed); - } - return size; -} - -bool StringValue::empty() const { return absl::visit(EmptyVisitor{}, rep()); } - -bool StringValue::Equals(absl::string_view string) const { - return absl::visit(EqualsVisitor(string), rep()); -} - -bool StringValue::Equals(const absl::Cord& string) const { - return absl::visit(EqualsVisitor(string), rep()); -} - -bool StringValue::Equals(const Persistent& string) const { - return absl::visit(EqualsVisitor(*this), string->rep()); -} - -int StringValue::Compare(absl::string_view string) const { - return absl::visit(CompareVisitor(string), rep()); -} - -int StringValue::Compare(const absl::Cord& string) const { - return absl::visit(CompareVisitor(string), rep()); -} - -int StringValue::Compare(const Persistent& string) const { - return absl::visit(CompareVisitor(*this), string->rep()); -} - -std::string StringValue::ToString() const { - return absl::visit(ToStringVisitor{}, rep()); -} - -std::string StringValue::DebugString() const { - return absl::visit(StringValueDebugStringVisitor{}, rep()); -} - -bool StringValue::Equals(const Value& other) const { - return kind() == other.kind() && - absl::visit(EqualsVisitor(*this), - internal::down_cast(other).rep()); -} - -void StringValue::HashValue(absl::HashState state) const { - absl::visit( - HashValueVisitor(absl::HashState::combine(std::move(state), type())), - rep()); -} - -struct EnumType::NewInstanceVisitor final { - const Persistent& enum_type; - ValueFactory& value_factory; - - absl::StatusOr> operator()( - absl::string_view name) const { - TypedEnumValueFactory factory(value_factory, enum_type); - return enum_type->NewInstanceByName(factory, name); - } - - absl::StatusOr> operator()(int64_t number) const { - TypedEnumValueFactory factory(value_factory, enum_type); - return enum_type->NewInstanceByNumber(factory, number); - } -}; - -absl::StatusOr> EnumValue::New( - const Persistent& enum_type, ValueFactory& value_factory, - EnumType::ConstantId id) { - CEL_ASSIGN_OR_RETURN( - auto enum_value, - absl::visit(EnumType::NewInstanceVisitor{enum_type, value_factory}, - id.data_)); - if (!enum_value->type_) { - // In case somebody is caching, we avoid setting the type_ if it has already - // been set, to avoid a race condition where one CPU sees a half written - // pointer. - const_cast(*enum_value).type_ = enum_type; - } - return enum_value; -} - -bool EnumValue::Equals(const Value& other) const { - return kind() == other.kind() && type() == other.type() && - number() == internal::down_cast(other).number(); -} - -void EnumValue::HashValue(absl::HashState state) const { - absl::HashState::combine(std::move(state), type(), number()); -} - -struct StructValue::SetFieldVisitor final { - StructValue& struct_value; - const Persistent& value; - - absl::Status operator()(absl::string_view name) const { - return struct_value.SetFieldByName(name, value); - } - - absl::Status operator()(int64_t number) const { - return struct_value.SetFieldByNumber(number, value); - } -}; - -struct StructValue::GetFieldVisitor final { - const StructValue& struct_value; - ValueFactory& value_factory; - - absl::StatusOr> operator()( - absl::string_view name) const { - return struct_value.GetFieldByName(value_factory, name); - } - - absl::StatusOr> operator()(int64_t number) const { - return struct_value.GetFieldByNumber(value_factory, number); - } -}; - -struct StructValue::HasFieldVisitor final { - const StructValue& struct_value; - - absl::StatusOr operator()(absl::string_view name) const { - return struct_value.HasFieldByName(name); - } - - absl::StatusOr operator()(int64_t number) const { - return struct_value.HasFieldByNumber(number); - } -}; - -absl::StatusOr> StructValue::New( - const Persistent& struct_type, - ValueFactory& value_factory) { - TypedStructValueFactory factory(value_factory, struct_type); - CEL_ASSIGN_OR_RETURN(auto struct_value, struct_type->NewInstance(factory)); - if (!struct_value->type_) { - // In case somebody is caching, we avoid setting the type_ if it has already - // been set, to avoid a race condition where one CPU sees a half written - // pointer. - const_cast(*struct_value).type_ = struct_type; - } - return struct_value; -} - -absl::Status StructValue::SetField(FieldId field, - const Persistent& value) { - return absl::visit(SetFieldVisitor{*this, value}, field.data_); -} - -absl::StatusOr> StructValue::GetField( - ValueFactory& value_factory, FieldId field) const { - return absl::visit(GetFieldVisitor{*this, value_factory}, field.data_); -} - -absl::StatusOr StructValue::HasField(FieldId field) const { - return absl::visit(HasFieldVisitor{*this}, field.data_); -} - -Persistent TypeValue::type() const { - return PersistentHandleFactory::MakeUnmanaged( - TypeType::Get()); -} - -std::string TypeValue::DebugString() const { return value()->DebugString(); } - -void TypeValue::CopyTo(Value& address) const { - CEL_COPY_TO_IMPL(TypeValue, *this, address); -} - -void TypeValue::MoveTo(Value& address) { - CEL_MOVE_TO_IMPL(TypeValue, *this, address); -} - -bool TypeValue::Equals(const Value& other) const { - return kind() == other.kind() && - value() == internal::down_cast(other).value(); -} - -void TypeValue::HashValue(absl::HashState state) const { - absl::HashState::combine(std::move(state), type(), value()); -} - -namespace base_internal { - -absl::Cord InlinedCordBytesValue::ToCord(bool reference_counted) const { - static_cast(reference_counted); - return value_; -} - -void InlinedCordBytesValue::CopyTo(Value& address) const { - CEL_COPY_TO_IMPL(InlinedCordBytesValue, *this, address); -} - -void InlinedCordBytesValue::MoveTo(Value& address) { - CEL_MOVE_TO_IMPL(InlinedCordBytesValue, *this, address); -} - -typename InlinedCordBytesValue::Rep InlinedCordBytesValue::rep() const { - return Rep(absl::in_place_type>, - std::cref(value_)); -} - -absl::Cord InlinedStringViewBytesValue::ToCord(bool reference_counted) const { - static_cast(reference_counted); - return absl::Cord(value_); -} - -void InlinedStringViewBytesValue::CopyTo(Value& address) const { - CEL_COPY_TO_IMPL(InlinedStringViewBytesValue, *this, address); -} - -void InlinedStringViewBytesValue::MoveTo(Value& address) { - CEL_MOVE_TO_IMPL(InlinedStringViewBytesValue, *this, address); -} - -typename InlinedStringViewBytesValue::Rep InlinedStringViewBytesValue::rep() - const { - return Rep(absl::in_place_type, value_); -} - -std::pair StringBytesValue::SizeAndAlignment() const { - return std::make_pair(sizeof(StringBytesValue), alignof(StringBytesValue)); -} - -absl::Cord StringBytesValue::ToCord(bool reference_counted) const { - if (reference_counted) { - Ref(); - return absl::MakeCordFromExternal(absl::string_view(value_), - [this]() { Unref(); }); - } - return absl::Cord(value_); -} - -typename StringBytesValue::Rep StringBytesValue::rep() const { - return Rep(absl::in_place_type, absl::string_view(value_)); -} - -std::pair ExternalDataBytesValue::SizeAndAlignment() const { - return std::make_pair(sizeof(ExternalDataBytesValue), - alignof(ExternalDataBytesValue)); -} - -absl::Cord ExternalDataBytesValue::ToCord(bool reference_counted) const { - if (reference_counted) { - Ref(); - return absl::MakeCordFromExternal( - absl::string_view(static_cast(value_.data), value_.size), - [this]() { Unref(); }); - } - return absl::Cord( - absl::string_view(static_cast(value_.data), value_.size)); -} - -typename ExternalDataBytesValue::Rep ExternalDataBytesValue::rep() const { - return Rep( - absl::in_place_type, - absl::string_view(static_cast(value_.data), value_.size)); -} - -absl::Cord InlinedCordStringValue::ToCord(bool reference_counted) const { - static_cast(reference_counted); - return value_; -} - -void InlinedCordStringValue::CopyTo(Value& address) const { - CEL_COPY_TO_IMPL(InlinedCordStringValue, *this, address); -} - -void InlinedCordStringValue::MoveTo(Value& address) { - CEL_MOVE_TO_IMPL(InlinedCordStringValue, *this, address); -} - -typename InlinedCordStringValue::Rep InlinedCordStringValue::rep() const { - return Rep(absl::in_place_type>, - std::cref(value_)); -} - -absl::Cord InlinedStringViewStringValue::ToCord(bool reference_counted) const { - static_cast(reference_counted); - return absl::Cord(value_); -} - -void InlinedStringViewStringValue::CopyTo(Value& address) const { - CEL_COPY_TO_IMPL(InlinedStringViewStringValue, *this, address); -} - -void InlinedStringViewStringValue::MoveTo(Value& address) { - CEL_MOVE_TO_IMPL(InlinedStringViewStringValue, *this, address); -} - -typename InlinedStringViewStringValue::Rep InlinedStringViewStringValue::rep() - const { - return Rep(absl::in_place_type, value_); -} - -std::pair StringStringValue::SizeAndAlignment() const { - return std::make_pair(sizeof(StringStringValue), alignof(StringStringValue)); -} - -absl::Cord StringStringValue::ToCord(bool reference_counted) const { - if (reference_counted) { - Ref(); - return absl::MakeCordFromExternal(absl::string_view(value_), - [this]() { Unref(); }); - } - return absl::Cord(value_); -} - -typename StringStringValue::Rep StringStringValue::rep() const { - return Rep(absl::in_place_type, absl::string_view(value_)); -} - -std::pair ExternalDataStringValue::SizeAndAlignment() const { - return std::make_pair(sizeof(ExternalDataStringValue), - alignof(ExternalDataStringValue)); -} - -absl::Cord ExternalDataStringValue::ToCord(bool reference_counted) const { - if (reference_counted) { - Ref(); - return absl::MakeCordFromExternal( - absl::string_view(static_cast(value_.data), value_.size), - [this]() { Unref(); }); - } - return absl::Cord( - absl::string_view(static_cast(value_.data), value_.size)); -} - -typename ExternalDataStringValue::Rep ExternalDataStringValue::rep() const { - return Rep( - absl::in_place_type, - absl::string_view(static_cast(value_.data), value_.size)); -} - -} // namespace base_internal - } // namespace cel diff --git a/base/value.h b/base/value.h index 6cc73e78a..04c678de1 100644 --- a/base/value.h +++ b/base/value.h @@ -21,13 +21,7 @@ #include #include -#include "absl/base/attributes.h" -#include "absl/base/macros.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/cord.h" #include "absl/strings/string_view.h" -#include "absl/time/time.h" #include "absl/types/optional.h" #include "absl/types/variant.h" #include "base/handle.h" @@ -35,25 +29,6 @@ #include "base/kind.h" #include "base/memory_manager.h" #include "base/type.h" -#include "base/types/any_type.h" -#include "base/types/bool_type.h" -#include "base/types/bytes_type.h" -#include "base/types/double_type.h" -#include "base/types/duration_type.h" -#include "base/types/dyn_type.h" -#include "base/types/enum_type.h" -#include "base/types/error_type.h" -#include "base/types/int_type.h" -#include "base/types/list_type.h" -#include "base/types/map_type.h" -#include "base/types/null_type.h" -#include "base/types/string_type.h" -#include "base/types/struct_type.h" -#include "base/types/timestamp_type.h" -#include "base/types/type_type.h" -#include "base/types/uint_type.h" -#include "internal/casts.h" -#include "internal/rtti.h" namespace cel { @@ -154,813 +129,6 @@ class Value : public base_internal::Resource { virtual void MoveTo(Value& address); }; -class NullValue final : public Value, public base_internal::ResourceInlined { - public: - static Persistent Get(ValueFactory& value_factory); - - Persistent type() const override; - - Kind kind() const override { return Kind::kNullType; } - - std::string DebugString() const override; - - // Note GCC does not consider a friend member as a member of a friend. - ABSL_ATTRIBUTE_PURE_FUNCTION static const NullValue& Get(); - - bool Equals(const Value& other) const override; - - void HashValue(absl::HashState state) const override; - - private: - friend class ValueFactory; - template - friend class internal::NoDestructor; - template - friend class base_internal::ValueHandle; - friend class base_internal::ValueHandleBase; - - // Called by base_internal::ValueHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Value& value) { return value.kind() == Kind::kNullType; } - - NullValue() = default; - NullValue(const NullValue&) = default; - NullValue(NullValue&&) = default; - - // See comments for respective member functions on `Value`. - void CopyTo(Value& address) const override; - void MoveTo(Value& address) override; -}; - -class ErrorValue final : public Value, public base_internal::ResourceInlined { - public: - Persistent type() const override; - - Kind kind() const override { return Kind::kError; } - - std::string DebugString() const override; - - const absl::Status& value() const { return value_; } - - private: - template - friend class base_internal::ValueHandle; - friend class base_internal::ValueHandleBase; - - // Called by base_internal::ValueHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Value& value) { return value.kind() == Kind::kError; } - - // Called by `base_internal::ValueHandle` to construct value inline. - explicit ErrorValue(absl::Status value) : value_(std::move(value)) {} - - ErrorValue() = delete; - - ErrorValue(const ErrorValue&) = default; - ErrorValue(ErrorValue&&) = default; - - // See comments for respective member functions on `Value`. - void CopyTo(Value& address) const override; - void MoveTo(Value& address) override; - bool Equals(const Value& other) const override; - void HashValue(absl::HashState state) const override; - - absl::Status value_; -}; - -class BoolValue final : public Value, public base_internal::ResourceInlined { - public: - static Persistent False(ValueFactory& value_factory); - - static Persistent True(ValueFactory& value_factory); - - Persistent type() const override; - - Kind kind() const override { return Kind::kBool; } - - std::string DebugString() const override; - - constexpr bool value() const { return value_; } - - private: - template - friend class base_internal::ValueHandle; - friend class base_internal::ValueHandleBase; - - // Called by base_internal::ValueHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Value& value) { return value.kind() == Kind::kBool; } - - // Called by `base_internal::ValueHandle` to construct value inline. - explicit BoolValue(bool value) : value_(value) {} - - BoolValue() = delete; - - BoolValue(const BoolValue&) = default; - BoolValue(BoolValue&&) = default; - - // See comments for respective member functions on `Value`. - void CopyTo(Value& address) const override; - void MoveTo(Value& address) override; - bool Equals(const Value& other) const override; - void HashValue(absl::HashState state) const override; - - bool value_; -}; - -class IntValue final : public Value, public base_internal::ResourceInlined { - public: - Persistent type() const override; - - Kind kind() const override { return Kind::kInt; } - - std::string DebugString() const override; - - constexpr int64_t value() const { return value_; } - - private: - template - friend class base_internal::ValueHandle; - friend class base_internal::ValueHandleBase; - - // Called by base_internal::ValueHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Value& value) { return value.kind() == Kind::kInt; } - - // Called by `base_internal::ValueHandle` to construct value inline. - explicit IntValue(int64_t value) : value_(value) {} - - IntValue() = delete; - - IntValue(const IntValue&) = default; - IntValue(IntValue&&) = default; - - // See comments for respective member functions on `Value`. - void CopyTo(Value& address) const override; - void MoveTo(Value& address) override; - bool Equals(const Value& other) const override; - void HashValue(absl::HashState state) const override; - - int64_t value_; -}; - -class UintValue final : public Value, public base_internal::ResourceInlined { - public: - Persistent type() const override; - - Kind kind() const override { return Kind::kUint; } - - std::string DebugString() const override; - - constexpr uint64_t value() const { return value_; } - - private: - template - friend class base_internal::ValueHandle; - friend class base_internal::ValueHandleBase; - - // Called by base_internal::ValueHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Value& value) { return value.kind() == Kind::kUint; } - - // Called by `base_internal::ValueHandle` to construct value inline. - explicit UintValue(uint64_t value) : value_(value) {} - - UintValue() = delete; - - UintValue(const UintValue&) = default; - UintValue(UintValue&&) = default; - - // See comments for respective member functions on `Value`. - void CopyTo(Value& address) const override; - void MoveTo(Value& address) override; - bool Equals(const Value& other) const override; - void HashValue(absl::HashState state) const override; - - uint64_t value_; -}; - -class DoubleValue final : public Value, public base_internal::ResourceInlined { - public: - static Persistent NaN(ValueFactory& value_factory); - - static Persistent PositiveInfinity( - ValueFactory& value_factory); - - static Persistent NegativeInfinity( - ValueFactory& value_factory); - - Persistent type() const override; - - Kind kind() const override { return Kind::kDouble; } - - std::string DebugString() const override; - - constexpr double value() const { return value_; } - - private: - template - friend class base_internal::ValueHandle; - friend class base_internal::ValueHandleBase; - - // Called by base_internal::ValueHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Value& value) { return value.kind() == Kind::kDouble; } - - // Called by `base_internal::ValueHandle` to construct value inline. - explicit DoubleValue(double value) : value_(value) {} - - DoubleValue() = delete; - - DoubleValue(const DoubleValue&) = default; - DoubleValue(DoubleValue&&) = default; - - // See comments for respective member functions on `Value`. - void CopyTo(Value& address) const override; - void MoveTo(Value& address) override; - bool Equals(const Value& other) const override; - void HashValue(absl::HashState state) const override; - - double value_; -}; - -class BytesValue : public Value { - protected: - using Rep = base_internal::BytesValueRep; - - public: - static Persistent Empty(ValueFactory& value_factory); - - // Concat concatenates the contents of two ByteValue, returning a new - // ByteValue. The resulting ByteValue is not tied to the lifetime of either of - // the input ByteValue. - static absl::StatusOr> Concat( - ValueFactory& value_factory, const Persistent& lhs, - const Persistent& rhs); - - Persistent type() const final; - - Kind kind() const final { return Kind::kBytes; } - - std::string DebugString() const final; - - size_t size() const; - - bool empty() const; - - bool Equals(absl::string_view bytes) const; - bool Equals(const absl::Cord& bytes) const; - bool Equals(const Persistent& bytes) const; - - int Compare(absl::string_view bytes) const; - int Compare(const absl::Cord& bytes) const; - int Compare(const Persistent& bytes) const; - - std::string ToString() const; - - absl::Cord ToCord() const { - // Without the handle we cannot know if this is reference counted. - return ToCord(/*reference_counted=*/false); - } - - private: - template - friend class base_internal::ValueHandle; - friend class base_internal::ValueHandleBase; - friend class base_internal::InlinedCordBytesValue; - friend class base_internal::InlinedStringViewBytesValue; - friend class base_internal::StringBytesValue; - friend class base_internal::ExternalDataBytesValue; - friend base_internal::BytesValueRep interop_internal::GetBytesValueRep( - const Persistent& value); - - // Called by base_internal::ValueHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Value& value) { return value.kind() == Kind::kBytes; } - - BytesValue() = default; - BytesValue(const BytesValue&) = default; - BytesValue(BytesValue&&) = default; - - // Get the contents of this BytesValue as absl::Cord. When reference_counted - // is true, the implementation can potentially return an absl::Cord that wraps - // the contents instead of copying. - virtual absl::Cord ToCord(bool reference_counted) const = 0; - - // Get the contents of this BytesValue as either absl::string_view or const - // absl::Cord&. - virtual Rep rep() const = 0; - - // See comments for respective member functions on `Value`. - bool Equals(const Value& other) const final; - void HashValue(absl::HashState state) const final; -}; - -class StringValue : public Value { - protected: - using Rep = base_internal::StringValueRep; - - public: - static Persistent Empty(ValueFactory& value_factory); - - static absl::StatusOr> Concat( - ValueFactory& value_factory, const Persistent& lhs, - const Persistent& rhs); - - Persistent type() const final; - - Kind kind() const final { return Kind::kString; } - - std::string DebugString() const final; - - size_t size() const; - - bool empty() const; - - bool Equals(absl::string_view string) const; - bool Equals(const absl::Cord& string) const; - bool Equals(const Persistent& string) const; - - int Compare(absl::string_view string) const; - int Compare(const absl::Cord& string) const; - int Compare(const Persistent& string) const; - - std::string ToString() const; - - absl::Cord ToCord() const { - // Without the handle we cannot know if this is reference counted. - return ToCord(/*reference_counted=*/false); - } - - private: - template - friend class base_internal::ValueHandle; - friend class base_internal::ValueHandleBase; - friend class base_internal::InlinedCordStringValue; - friend class base_internal::InlinedStringViewStringValue; - friend class base_internal::StringStringValue; - friend class base_internal::ExternalDataStringValue; - friend base_internal::StringValueRep interop_internal::GetStringValueRep( - const Persistent& value); - - // Called by base_internal::ValueHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Value& value) { return value.kind() == Kind::kString; } - - explicit StringValue(size_t size) : size_(size) {} - - StringValue() = default; - - StringValue(const StringValue& other) - : StringValue(other.size_.load(std::memory_order_relaxed)) {} - - StringValue(StringValue&& other) - : StringValue(other.size_.exchange(0, std::memory_order_relaxed)) {} - - // Get the contents of this BytesValue as absl::Cord. When reference_counted - // is true, the implementation can potentially return an absl::Cord that wraps - // the contents instead of copying. - virtual absl::Cord ToCord(bool reference_counted) const = 0; - - // Get the contents of this StringValue as either absl::string_view or const - // absl::Cord&. - virtual Rep rep() const = 0; - - // See comments for respective member functions on `Value`. - bool Equals(const Value& other) const final; - void HashValue(absl::HashState state) const final; - - // Lazily cached code point count. - mutable std::atomic size_ = 0; -}; - -class DurationValue final : public Value, - public base_internal::ResourceInlined { - public: - static Persistent Zero(ValueFactory& value_factory); - - Persistent type() const override; - - Kind kind() const override { return Kind::kDuration; } - - std::string DebugString() const override; - - constexpr absl::Duration value() const { return value_; } - - private: - template - friend class base_internal::ValueHandle; - friend class base_internal::ValueHandleBase; - - // Called by base_internal::ValueHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Value& value) { return value.kind() == Kind::kDuration; } - - // Called by `base_internal::ValueHandle` to construct value inline. - explicit DurationValue(absl::Duration value) : value_(value) {} - - DurationValue() = delete; - - DurationValue(const DurationValue&) = default; - DurationValue(DurationValue&&) = default; - - // See comments for respective member functions on `Value`. - void CopyTo(Value& address) const override; - void MoveTo(Value& address) override; - bool Equals(const Value& other) const override; - void HashValue(absl::HashState state) const override; - - absl::Duration value_; -}; - -class TimestampValue final : public Value, - public base_internal::ResourceInlined { - public: - static Persistent UnixEpoch( - ValueFactory& value_factory); - - Persistent type() const override; - - Kind kind() const override { return Kind::kTimestamp; } - - std::string DebugString() const override; - - constexpr absl::Time value() const { return value_; } - - private: - template - friend class base_internal::ValueHandle; - friend class base_internal::ValueHandleBase; - - // Called by base_internal::ValueHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Value& value) { - return value.kind() == Kind::kTimestamp; - } - - // Called by `base_internal::ValueHandle` to construct value inline. - explicit TimestampValue(absl::Time value) : value_(value) {} - - TimestampValue() = delete; - - TimestampValue(const TimestampValue&) = default; - TimestampValue(TimestampValue&&) = default; - - // See comments for respective member functions on `Value`. - void CopyTo(Value& address) const override; - void MoveTo(Value& address) override; - bool Equals(const Value& other) const override; - void HashValue(absl::HashState state) const override; - - absl::Time value_; -}; - -// EnumValue represents a single constant belonging to cel::EnumType. -class EnumValue : public Value { - public: - static absl::StatusOr> New( - const Persistent& enum_type, ValueFactory& value_factory, - EnumType::ConstantId id); - - Persistent type() const final { return type_; } - - Kind kind() const final { return Kind::kEnum; } - - virtual int64_t number() const = 0; - - virtual absl::string_view name() const = 0; - - protected: - explicit EnumValue(const Persistent& type) : type_(type) { - ABSL_ASSERT(type_); - } - - private: - friend internal::TypeInfo base_internal::GetEnumValueTypeId( - const EnumValue& enum_value); - template - friend class base_internal::ValueHandle; - friend class base_internal::ValueHandleBase; - - // Called by base_internal::ValueHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Value& value) { return value.kind() == Kind::kEnum; } - - EnumValue(const EnumValue&) = delete; - EnumValue(EnumValue&&) = delete; - - bool Equals(const Value& other) const final; - void HashValue(absl::HashState state) const final; - - std::pair SizeAndAlignment() const override = 0; - - // Called by CEL_IMPLEMENT_ENUM_VALUE() and Is() to perform type checking. - virtual internal::TypeInfo TypeId() const = 0; - - Persistent type_; -}; - -// CEL_DECLARE_ENUM_VALUE declares `enum_value` as an enumeration value. It must -// be part of the class definition of `enum_value`. -// -// class MyEnumValue : public cel::EnumValue { -// ... -// private: -// CEL_DECLARE_ENUM_VALUE(MyEnumValue); -// }; -#define CEL_DECLARE_ENUM_VALUE(enum_value) \ - CEL_INTERNAL_DECLARE_VALUE(Enum, enum_value) - -// CEL_IMPLEMENT_ENUM_VALUE implements `enum_value` as an enumeration value. It -// must be called after the class definition of `enum_value`. -// -// class MyEnumValue : public cel::EnumValue { -// ... -// private: -// CEL_DECLARE_ENUM_VALUE(MyEnumValue); -// }; -// -// CEL_IMPLEMENT_ENUM_VALUE(MyEnumValue); -#define CEL_IMPLEMENT_ENUM_VALUE(enum_value) \ - CEL_INTERNAL_IMPLEMENT_VALUE(Enum, enum_value) - -// StructValue represents an instance of cel::StructType. -class StructValue : public Value { - public: - using FieldId = StructType::FieldId; - - static absl::StatusOr> New( - const Persistent& struct_type, - ValueFactory& value_factory); - - Persistent type() const final { return type_; } - - Kind kind() const final { return Kind::kStruct; } - - absl::Status SetField(FieldId field, const Persistent& value); - - absl::StatusOr> GetField(ValueFactory& value_factory, - FieldId field) const; - - absl::StatusOr HasField(FieldId field) const; - - protected: - explicit StructValue(const Persistent& type) : type_(type) { - ABSL_ASSERT(type_); - } - - virtual absl::Status SetFieldByName(absl::string_view name, - const Persistent& value) = 0; - - virtual absl::Status SetFieldByNumber( - int64_t number, const Persistent& value) = 0; - - virtual absl::StatusOr> GetFieldByName( - ValueFactory& value_factory, absl::string_view name) const = 0; - - virtual absl::StatusOr> GetFieldByNumber( - ValueFactory& value_factory, int64_t number) const = 0; - - virtual absl::StatusOr HasFieldByName(absl::string_view name) const = 0; - - virtual absl::StatusOr HasFieldByNumber(int64_t number) const = 0; - - private: - struct SetFieldVisitor; - struct GetFieldVisitor; - struct HasFieldVisitor; - - friend struct SetFieldVisitor; - friend struct GetFieldVisitor; - friend struct HasFieldVisitor; - friend internal::TypeInfo base_internal::GetStructValueTypeId( - const StructValue& struct_value); - template - friend class base_internal::ValueHandle; - friend class base_internal::ValueHandleBase; - - // Called by base_internal::ValueHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Value& value) { return value.kind() == Kind::kStruct; } - - StructValue(const StructValue&) = delete; - StructValue(StructValue&&) = delete; - - bool Equals(const Value& other) const override = 0; - void HashValue(absl::HashState state) const override = 0; - - std::pair SizeAndAlignment() const override = 0; - - // Called by CEL_IMPLEMENT_STRUCT_VALUE() and Is() to perform type checking. - virtual internal::TypeInfo TypeId() const = 0; - - Persistent type_; -}; - -// CEL_DECLARE_STRUCT_VALUE declares `struct_value` as an struct value. It must -// be part of the class definition of `struct_value`. -// -// class MyStructValue : public cel::StructValue { -// ... -// private: -// CEL_DECLARE_STRUCT_VALUE(MyStructValue); -// }; -#define CEL_DECLARE_STRUCT_VALUE(struct_value) \ - CEL_INTERNAL_DECLARE_VALUE(Struct, struct_value) - -// CEL_IMPLEMENT_STRUCT_VALUE implements `struct_value` as an struct -// value. It must be called after the class definition of `struct_value`. -// -// class MyStructValue : public cel::StructValue { -// ... -// private: -// CEL_DECLARE_STRUCT_VALUE(MyStructValue); -// }; -// -// CEL_IMPLEMENT_STRUCT_VALUE(MyStructValue); -#define CEL_IMPLEMENT_STRUCT_VALUE(struct_value) \ - CEL_INTERNAL_IMPLEMENT_VALUE(Struct, struct_value) - -// ListValue represents an instance of cel::ListType. -class ListValue : public Value { - public: - // TODO(issues/5): implement iterators so we can have cheap concated lists - - Persistent type() const final { return type_; } - - Kind kind() const final { return Kind::kList; } - - virtual size_t size() const = 0; - - virtual bool empty() const { return size() == 0; } - - virtual absl::StatusOr> Get( - ValueFactory& value_factory, size_t index) const = 0; - - protected: - explicit ListValue(const Persistent& type) : type_(type) {} - - private: - friend internal::TypeInfo base_internal::GetListValueTypeId( - const ListValue& list_value); - template - friend class base_internal::ValueHandle; - friend class base_internal::ValueHandleBase; - - // Called by base_internal::ValueHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Value& value) { return value.kind() == Kind::kList; } - - ListValue(const ListValue&) = delete; - ListValue(ListValue&&) = delete; - - // TODO(issues/5): I do not like this, we should have these two take a - // ValueFactory and return absl::StatusOr and absl::Status. We support - // lazily created values, so errors can occur during equality testing. - // Especially if there are different value implementations for the same type. - bool Equals(const Value& other) const override = 0; - void HashValue(absl::HashState state) const override = 0; - - std::pair SizeAndAlignment() const override = 0; - - // Called by CEL_IMPLEMENT_LIST_VALUE() and Is() to perform type checking. - virtual internal::TypeInfo TypeId() const = 0; - - const Persistent type_; -}; - -// CEL_DECLARE_LIST_VALUE declares `list_value` as an list value. It must -// be part of the class definition of `list_value`. -// -// class MyListValue : public cel::ListValue { -// ... -// private: -// CEL_DECLARE_LIST_VALUE(MyListValue); -// }; -#define CEL_DECLARE_LIST_VALUE(list_value) \ - CEL_INTERNAL_DECLARE_VALUE(List, list_value) - -// CEL_IMPLEMENT_LIST_VALUE implements `list_value` as an list -// value. It must be called after the class definition of `list_value`. -// -// class MyListValue : public cel::ListValue { -// ... -// private: -// CEL_DECLARE_LIST_VALUE(MyListValue); -// }; -// -// CEL_IMPLEMENT_LIST_VALUE(MyListValue); -#define CEL_IMPLEMENT_LIST_VALUE(list_value) \ - CEL_INTERNAL_IMPLEMENT_VALUE(List, list_value) - -// MapValue represents an instance of cel::MapType. -class MapValue : public Value { - public: - Persistent type() const final { return type_; } - - Kind kind() const final { return Kind::kMap; } - - virtual size_t size() const = 0; - - virtual bool empty() const { return size() == 0; } - - virtual absl::StatusOr> Get( - ValueFactory& value_factory, - const Persistent& key) const = 0; - - virtual absl::StatusOr Has( - const Persistent& key) const = 0; - - protected: - explicit MapValue(const Persistent& type) : type_(type) {} - - private: - friend internal::TypeInfo base_internal::GetMapValueTypeId( - const MapValue& map_value); - template - friend class base_internal::ValueHandle; - friend class base_internal::ValueHandleBase; - - // Called by base_internal::ValueHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Value& value) { return value.kind() == Kind::kMap; } - - MapValue(const MapValue&) = delete; - MapValue(MapValue&&) = delete; - - bool Equals(const Value& other) const override = 0; - void HashValue(absl::HashState state) const override = 0; - - std::pair SizeAndAlignment() const override = 0; - - // Called by CEL_IMPLEMENT_MAP_VALUE() and Is() to perform type checking. - virtual internal::TypeInfo TypeId() const = 0; - - // Set lazily, by EnumValue::New. - Persistent type_; -}; - -// CEL_DECLARE_MAP_VALUE declares `map_value` as an map value. It must -// be part of the class definition of `map_value`. -// -// class MyMapValue : public cel::MapValue { -// ... -// private: -// CEL_DECLARE_MAP_VALUE(MyMapValue); -// }; -#define CEL_DECLARE_MAP_VALUE(map_value) \ - CEL_INTERNAL_DECLARE_VALUE(Map, map_value) - -// CEL_IMPLEMENT_MAP_VALUE implements `map_value` as an map -// value. It must be called after the class definition of `map_value`. -// -// class MyMapValue : public cel::MapValue { -// ... -// private: -// CEL_DECLARE_MAP_VALUE(MyMapValue); -// }; -// -// CEL_IMPLEMENT_MAP_VALUE(MyMapValue); -#define CEL_IMPLEMENT_MAP_VALUE(map_value) \ - CEL_INTERNAL_IMPLEMENT_VALUE(Map, map_value) - -// TypeValue represents an instance of cel::Type. -class TypeValue final : public Value, base_internal::ResourceInlined { - public: - Persistent type() const override; - - Kind kind() const override { return Kind::kType; } - - std::string DebugString() const override; - - Persistent value() const { return value_; } - - private: - template - friend class base_internal::ValueHandle; - friend class base_internal::ValueHandleBase; - - // Called by base_internal::ValueHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Value& value) { return value.kind() == Kind::kType; } - - // Called by `base_internal::ValueHandle` to construct value inline. - explicit TypeValue(Persistent type) : value_(std::move(type)) {} - - TypeValue() = delete; - - TypeValue(const TypeValue&) = default; - TypeValue(TypeValue&&) = default; - - // See comments for respective member functions on `Value`. - void CopyTo(Value& address) const override; - void MoveTo(Value& address) override; - bool Equals(const Value& other) const override; - void HashValue(absl::HashState state) const override; - - Persistent value_; -}; - } // namespace cel // value.pre.h forward declares types so they can be friended above. The types diff --git a/base/value_factory.cc b/base/value_factory.cc index 26fdd5a26..078294442 100644 --- a/base/value_factory.cc +++ b/base/value_factory.cc @@ -40,8 +40,20 @@ using base_internal::PersistentHandleFactory; using base_internal::StringBytesValue; using base_internal::StringStringValue; +template +bool CanPerformZeroCopy(MemoryManager& memory_manager, + const Persistent& handle) { + return base_internal::IsManagedHandle(handle) && + std::addressof(memory_manager) == + std::addressof(base_internal::GetMemoryManager(handle)); +} + } // namespace +Persistent NullValue::Get(ValueFactory& value_factory) { + return value_factory.GetNullValue(); +} + Persistent ValueFactory::GetNullValue() { return Persistent( PersistentHandleFactory::MakeUnmanaged( @@ -59,6 +71,135 @@ Persistent ValueFactory::CreateErrorValue( std::move(status)); } +Persistent BoolValue::False(ValueFactory& value_factory) { + return value_factory.CreateBoolValue(false); +} + +Persistent BoolValue::True(ValueFactory& value_factory) { + return value_factory.CreateBoolValue(true); +} + +Persistent DoubleValue::NaN(ValueFactory& value_factory) { + return value_factory.CreateDoubleValue( + std::numeric_limits::quiet_NaN()); +} + +Persistent DoubleValue::PositiveInfinity( + ValueFactory& value_factory) { + return value_factory.CreateDoubleValue( + std::numeric_limits::infinity()); +} + +Persistent DoubleValue::NegativeInfinity( + ValueFactory& value_factory) { + return value_factory.CreateDoubleValue( + -std::numeric_limits::infinity()); +} + +Persistent DurationValue::Zero( + ValueFactory& value_factory) { + // Should never fail, tests assert this. + return value_factory.CreateDurationValue(absl::ZeroDuration()).value(); +} + +Persistent TimestampValue::UnixEpoch( + ValueFactory& value_factory) { + // Should never fail, tests assert this. + return value_factory.CreateTimestampValue(absl::UnixEpoch()).value(); +} + +Persistent StringValue::Empty(ValueFactory& value_factory) { + return value_factory.GetStringValue(); +} + +absl::StatusOr> StringValue::Concat( + ValueFactory& value_factory, const Persistent& lhs, + const Persistent& rhs) { + absl::Cord cord; + // We can only use the potential zero-copy path if the memory managers are + // the same. Otherwise we need to escape the original memory manager scope. + cord.Append( + lhs->ToCord(CanPerformZeroCopy(value_factory.memory_manager(), lhs))); + // We can only use the potential zero-copy path if the memory managers are + // the same. Otherwise we need to escape the original memory manager scope. + cord.Append( + rhs->ToCord(CanPerformZeroCopy(value_factory.memory_manager(), rhs))); + size_t size = 0; + size_t lhs_size = lhs->size_.load(std::memory_order_relaxed); + if (lhs_size != 0 && !lhs->empty()) { + size_t rhs_size = rhs->size_.load(std::memory_order_relaxed); + if (rhs_size != 0 && !rhs->empty()) { + size = lhs_size + rhs_size; + } + } + return value_factory.CreateStringValue(std::move(cord), size); +} + +Persistent BytesValue::Empty(ValueFactory& value_factory) { + return value_factory.GetBytesValue(); +} + +absl::StatusOr> BytesValue::Concat( + ValueFactory& value_factory, const Persistent& lhs, + const Persistent& rhs) { + absl::Cord cord; + // We can only use the potential zero-copy path if the memory managers are + // the same. Otherwise we need to escape the original memory manager scope. + cord.Append( + lhs->ToCord(CanPerformZeroCopy(value_factory.memory_manager(), lhs))); + // We can only use the potential zero-copy path if the memory managers are + // the same. Otherwise we need to escape the original memory manager scope. + cord.Append( + rhs->ToCord(CanPerformZeroCopy(value_factory.memory_manager(), rhs))); + return value_factory.CreateBytesValue(std::move(cord)); +} + +struct EnumType::NewInstanceVisitor final { + const Persistent& enum_type; + ValueFactory& value_factory; + + absl::StatusOr> operator()( + absl::string_view name) const { + TypedEnumValueFactory factory(value_factory, enum_type); + return enum_type->NewInstanceByName(factory, name); + } + + absl::StatusOr> operator()(int64_t number) const { + TypedEnumValueFactory factory(value_factory, enum_type); + return enum_type->NewInstanceByNumber(factory, number); + } +}; + +absl::StatusOr> EnumValue::New( + const Persistent& enum_type, ValueFactory& value_factory, + EnumType::ConstantId id) { + CEL_ASSIGN_OR_RETURN( + auto enum_value, + absl::visit(EnumType::NewInstanceVisitor{enum_type, value_factory}, + id.data_)); + if (!enum_value->type_) { + // In case somebody is caching, we avoid setting the type_ if it has already + // been set, to avoid a race condition where one CPU sees a half written + // pointer. + const_cast(*enum_value).type_ = enum_type; + } + return enum_value; +} + +absl::StatusOr> StructValue::New( + const Persistent& struct_type, + ValueFactory& value_factory) { + TypedStructValueFactory factory(value_factory, struct_type); + CEL_ASSIGN_OR_RETURN(auto struct_value, struct_type->NewInstance(factory)); + if (!struct_value->type_) { + // In case somebody is caching, we avoid setting the type_ if it has already + // been set, to avoid a race condition where one CPU sees a half written + // pointer. + const_cast(*struct_value).type_ = struct_type; + } + return struct_value; +} + Persistent ValueFactory::CreateBoolValue(bool value) { return PersistentHandleFactory::Make(value); } diff --git a/base/value_factory.h b/base/value_factory.h index ad13b750b..359e5ef6b 100644 --- a/base/value_factory.h +++ b/base/value_factory.h @@ -31,6 +31,21 @@ #include "base/memory_manager.h" #include "base/type_manager.h" #include "base/value.h" +#include "base/values/bool_value.h" +#include "base/values/bytes_value.h" +#include "base/values/double_value.h" +#include "base/values/duration_value.h" +#include "base/values/enum_value.h" +#include "base/values/error_value.h" +#include "base/values/int_value.h" +#include "base/values/list_value.h" +#include "base/values/map_value.h" +#include "base/values/null_value.h" +#include "base/values/string_value.h" +#include "base/values/struct_value.h" +#include "base/values/timestamp_value.h" +#include "base/values/type_value.h" +#include "base/values/uint_value.h" namespace cel { diff --git a/base/value_test.cc b/base/value_test.cc index b4ee3aaa7..926d7472a 100644 --- a/base/value_test.cc +++ b/base/value_test.cc @@ -533,7 +533,7 @@ TEST_P(ValueTest, DefaultConstructor) { TypeManager type_manager(type_factory, TypeProvider::Builtin()); ValueFactory value_factory(type_manager); Persistent value; - EXPECT_EQ(value, value_factory.GetNullValue()); + EXPECT_FALSE(value); } struct ConstructionAssignmentTestCase final { @@ -564,7 +564,7 @@ TEST_P(ConstructionAssignmentTest, MoveConstructor) { test_case().default_value(type_factory, value_factory)); Persistent to(std::move(from)); IS_INITIALIZED(from); - EXPECT_EQ(from, value_factory.GetNullValue()); + EXPECT_FALSE(from); EXPECT_EQ(to, test_case().default_value(type_factory, value_factory)); } @@ -588,7 +588,7 @@ TEST_P(ConstructionAssignmentTest, MoveAssignment) { Persistent to; to = std::move(from); IS_INITIALIZED(from); - EXPECT_EQ(from, value_factory.GetNullValue()); + EXPECT_FALSE(from); EXPECT_EQ(to, test_case().default_value(type_factory, value_factory)); } diff --git a/base/values/BUILD b/base/values/BUILD new file mode 100644 index 000000000..f693cedfc --- /dev/null +++ b/base/values/BUILD @@ -0,0 +1,259 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package( + # Under active development, not yet being released. + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) + +cc_library( + name = "bool", + srcs = ["bool_value.cc"], + hdrs = ["bool_value.h"], + deps = [ + "//base:kind", + "//base:type", + "//base:value", + "//base/types:bool", + "//internal:casts", + "@com_google_absl//absl/hash", + ], +) + +cc_library( + name = "bytes", + srcs = ["bytes_value.cc"], + hdrs = ["bytes_value.h"], + deps = [ + "//base:kind", + "//base:type", + "//base:value", + "//base/types:bytes", + "//internal:casts", + "//internal:strings", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + ], +) + +cc_library( + name = "double", + srcs = ["double_value.cc"], + hdrs = ["double_value.h"], + deps = [ + "//base:kind", + "//base:type", + "//base:value", + "//base/types:double", + "//internal:casts", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "duration", + srcs = ["duration_value.cc"], + hdrs = ["duration_value.h"], + deps = [ + "//base:kind", + "//base:type", + "//base:value", + "//base/types:duration", + "//internal:casts", + "//internal:time", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/time", + ], +) + +cc_library( + name = "enum", + srcs = ["enum_value.cc"], + hdrs = ["enum_value.h"], + deps = [ + "//base:kind", + "//base:type", + "//base:value", + "//base/types:enum", + "//internal:casts", + "//internal:rtti", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "error", + srcs = ["error_value.cc"], + hdrs = ["error_value.h"], + deps = [ + "//base:kind", + "//base:type", + "//base:value", + "//base/types:error", + "//internal:casts", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/status", + ], +) + +cc_library( + name = "int", + srcs = ["int_value.cc"], + hdrs = ["int_value.h"], + deps = [ + "//base:kind", + "//base:type", + "//base:value", + "//base/types:int", + "//internal:casts", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "list", + srcs = ["list_value.cc"], + hdrs = ["list_value.h"], + deps = [ + "//base:kind", + "//base:type", + "//base:value", + "//base/types:list", + "//internal:rtti", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "map", + srcs = ["map_value.cc"], + hdrs = ["map_value.h"], + deps = [ + "//base:kind", + "//base:type", + "//base:value", + "//base/types:map", + "//internal:rtti", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "null", + srcs = ["null_value.cc"], + hdrs = ["null_value.h"], + deps = [ + "//base:kind", + "//base:type", + "//base:value", + "//base/types:null", + "//internal:no_destructor", + "@com_google_absl//absl/hash", + ], +) + +cc_library( + name = "string", + srcs = ["string_value.cc"], + hdrs = ["string_value.h"], + deps = [ + "//base:kind", + "//base:type", + "//base:value", + "//base/types:string", + "//internal:casts", + "//internal:strings", + "//internal:utf8", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + ], +) + +cc_library( + name = "struct", + srcs = ["struct_value.cc"], + hdrs = ["struct_value.h"], + deps = [ + "//base:kind", + "//base:type", + "//base:value", + "//base/types:struct", + "//internal:casts", + "//internal:rtti", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "timestamp", + srcs = ["timestamp_value.cc"], + hdrs = ["timestamp_value.h"], + deps = [ + "//base:kind", + "//base:type", + "//base:value", + "//base/types:timestamp", + "//internal:casts", + "//internal:time", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/time", + ], +) + +cc_library( + name = "type", + srcs = ["type_value.cc"], + hdrs = ["type_value.h"], + deps = [ + "//base:kind", + "//base:type", + "//base:value", + "//base/types:type", + "//internal:casts", + "@com_google_absl//absl/hash", + ], +) + +cc_library( + name = "uint", + srcs = ["uint_value.cc"], + hdrs = ["uint_value.h"], + deps = [ + "//base:kind", + "//base:type", + "//base:value", + "//base/types:uint", + "//internal:casts", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/strings", + ], +) diff --git a/base/values/bool_value.cc b/base/values/bool_value.cc new file mode 100644 index 000000000..6cec63cf9 --- /dev/null +++ b/base/values/bool_value.cc @@ -0,0 +1,57 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/values/bool_value.h" + +#include +#include + +#include "base/types/bool_type.h" +#include "internal/casts.h" + +namespace cel { + +namespace { + +using base_internal::PersistentHandleFactory; + +} + +Persistent BoolValue::type() const { + return PersistentHandleFactory::MakeUnmanaged( + BoolType::Get()); +} + +std::string BoolValue::DebugString() const { + return value() ? "true" : "false"; +} + +void BoolValue::CopyTo(Value& address) const { + CEL_INTERNAL_VALUE_COPY_TO(BoolValue, *this, address); +} + +void BoolValue::MoveTo(Value& address) { + CEL_INTERNAL_VALUE_MOVE_TO(BoolValue, *this, address); +} + +bool BoolValue::Equals(const Value& other) const { + return kind() == other.kind() && + value() == internal::down_cast(other).value(); +} + +void BoolValue::HashValue(absl::HashState state) const { + absl::HashState::combine(std::move(state), type(), value()); +} + +} // namespace cel diff --git a/base/values/bool_value.h b/base/values/bool_value.h new file mode 100644 index 000000000..5b34f7e1c --- /dev/null +++ b/base/values/bool_value.h @@ -0,0 +1,71 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_VALUES_BOOL_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_BASE_VALUES_BOOL_VALUE_H_ + +#include + +#include "absl/hash/hash.h" +#include "base/kind.h" +#include "base/type.h" +#include "base/value.h" + +namespace cel { + +class ValueFactory; + +class BoolValue final : public Value, public base_internal::ResourceInlined { + public: + static Persistent False(ValueFactory& value_factory); + + static Persistent True(ValueFactory& value_factory); + + Persistent type() const override; + + Kind kind() const override { return Kind::kBool; } + + std::string DebugString() const override; + + constexpr bool value() const { return value_; } + + private: + template + friend class base_internal::ValueHandle; + friend class base_internal::ValueHandleBase; + + // Called by base_internal::ValueHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Value& value) { return value.kind() == Kind::kBool; } + + // Called by `base_internal::ValueHandle` to construct value inline. + explicit BoolValue(bool value) : value_(value) {} + + BoolValue() = delete; + + BoolValue(const BoolValue&) = default; + BoolValue(BoolValue&&) = default; + + // See comments for respective member functions on `Value`. + void CopyTo(Value& address) const override; + void MoveTo(Value& address) override; + bool Equals(const Value& other) const override; + void HashValue(absl::HashState state) const override; + + bool value_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_VALUES_BOOL_VALUE_H_ diff --git a/base/values/bytes_value.cc b/base/values/bytes_value.cc new file mode 100644 index 000000000..d3f0d739c --- /dev/null +++ b/base/values/bytes_value.cc @@ -0,0 +1,306 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/values/bytes_value.h" + +#include +#include + +#include "base/types/bytes_type.h" +#include "internal/casts.h" +#include "internal/strings.h" + +namespace cel { + +namespace { + +using base_internal::PersistentHandleFactory; + +struct BytesValueDebugStringVisitor final { + std::string operator()(absl::string_view value) const { + return internal::FormatBytesLiteral(value); + } + + std::string operator()(const absl::Cord& value) const { + return internal::FormatBytesLiteral(static_cast(value)); + } +}; + +struct ToStringVisitor final { + std::string operator()(absl::string_view value) const { + return std::string(value); + } + + std::string operator()(const absl::Cord& value) const { + return static_cast(value); + } +}; + +struct BytesValueSizeVisitor final { + size_t operator()(absl::string_view value) const { return value.size(); } + + size_t operator()(const absl::Cord& value) const { return value.size(); } +}; + +struct EmptyVisitor final { + bool operator()(absl::string_view value) const { return value.empty(); } + + bool operator()(const absl::Cord& value) const { return value.empty(); } +}; + +bool EqualsImpl(absl::string_view lhs, absl::string_view rhs) { + return lhs == rhs; +} + +bool EqualsImpl(absl::string_view lhs, const absl::Cord& rhs) { + return lhs == rhs; +} + +bool EqualsImpl(const absl::Cord& lhs, absl::string_view rhs) { + return lhs == rhs; +} + +bool EqualsImpl(const absl::Cord& lhs, const absl::Cord& rhs) { + return lhs == rhs; +} + +int CompareImpl(absl::string_view lhs, absl::string_view rhs) { + return lhs.compare(rhs); +} + +int CompareImpl(absl::string_view lhs, const absl::Cord& rhs) { + return -rhs.Compare(lhs); +} + +int CompareImpl(const absl::Cord& lhs, absl::string_view rhs) { + return lhs.Compare(rhs); +} + +int CompareImpl(const absl::Cord& lhs, const absl::Cord& rhs) { + return lhs.Compare(rhs); +} + +template +class EqualsVisitor final { + public: + explicit EqualsVisitor(const T& ref) : ref_(ref) {} + + bool operator()(absl::string_view value) const { + return EqualsImpl(value, ref_); + } + + bool operator()(const absl::Cord& value) const { + return EqualsImpl(value, ref_); + } + + private: + const T& ref_; +}; + +template <> +class EqualsVisitor final { + public: + explicit EqualsVisitor(const BytesValue& ref) : ref_(ref) {} + + bool operator()(absl::string_view value) const { return ref_.Equals(value); } + + bool operator()(const absl::Cord& value) const { return ref_.Equals(value); } + + private: + const BytesValue& ref_; +}; + +template +class CompareVisitor final { + public: + explicit CompareVisitor(const T& ref) : ref_(ref) {} + + int operator()(absl::string_view value) const { + return CompareImpl(value, ref_); + } + + int operator()(const absl::Cord& value) const { + return CompareImpl(value, ref_); + } + + private: + const T& ref_; +}; + +template <> +class CompareVisitor final { + public: + explicit CompareVisitor(const BytesValue& ref) : ref_(ref) {} + + int operator()(const absl::Cord& value) const { return ref_.Compare(value); } + + int operator()(absl::string_view value) const { return ref_.Compare(value); } + + private: + const BytesValue& ref_; +}; + +class HashValueVisitor final { + public: + explicit HashValueVisitor(absl::HashState state) : state_(std::move(state)) {} + + void operator()(absl::string_view value) { + absl::HashState::combine(std::move(state_), value); + } + + void operator()(const absl::Cord& value) { + absl::HashState::combine(std::move(state_), value); + } + + private: + absl::HashState state_; +}; + +} // namespace + +Persistent BytesValue::type() const { + return PersistentHandleFactory::MakeUnmanaged( + BytesType::Get()); +} + +size_t BytesValue::size() const { + return absl::visit(BytesValueSizeVisitor{}, rep()); +} + +bool BytesValue::empty() const { return absl::visit(EmptyVisitor{}, rep()); } + +bool BytesValue::Equals(absl::string_view bytes) const { + return absl::visit(EqualsVisitor(bytes), rep()); +} + +bool BytesValue::Equals(const absl::Cord& bytes) const { + return absl::visit(EqualsVisitor(bytes), rep()); +} + +bool BytesValue::Equals(const Persistent& bytes) const { + return absl::visit(EqualsVisitor(*this), bytes->rep()); +} + +int BytesValue::Compare(absl::string_view bytes) const { + return absl::visit(CompareVisitor(bytes), rep()); +} + +int BytesValue::Compare(const absl::Cord& bytes) const { + return absl::visit(CompareVisitor(bytes), rep()); +} + +int BytesValue::Compare(const Persistent& bytes) const { + return absl::visit(CompareVisitor(*this), bytes->rep()); +} + +std::string BytesValue::ToString() const { + return absl::visit(ToStringVisitor{}, rep()); +} + +std::string BytesValue::DebugString() const { + return absl::visit(BytesValueDebugStringVisitor{}, rep()); +} + +bool BytesValue::Equals(const Value& other) const { + return kind() == other.kind() && + absl::visit(EqualsVisitor(*this), + internal::down_cast(other).rep()); +} + +void BytesValue::HashValue(absl::HashState state) const { + absl::visit( + HashValueVisitor(absl::HashState::combine(std::move(state), type())), + rep()); +} + +namespace base_internal { + +absl::Cord InlinedCordBytesValue::ToCord(bool reference_counted) const { + static_cast(reference_counted); + return value_; +} + +void InlinedCordBytesValue::CopyTo(Value& address) const { + CEL_INTERNAL_VALUE_COPY_TO(InlinedCordBytesValue, *this, address); +} + +void InlinedCordBytesValue::MoveTo(Value& address) { + CEL_INTERNAL_VALUE_MOVE_TO(InlinedCordBytesValue, *this, address); +} + +typename InlinedCordBytesValue::Rep InlinedCordBytesValue::rep() const { + return Rep(absl::in_place_type>, + std::cref(value_)); +} + +absl::Cord InlinedStringViewBytesValue::ToCord(bool reference_counted) const { + static_cast(reference_counted); + return absl::Cord(value_); +} + +void InlinedStringViewBytesValue::CopyTo(Value& address) const { + CEL_INTERNAL_VALUE_COPY_TO(InlinedStringViewBytesValue, *this, address); +} + +void InlinedStringViewBytesValue::MoveTo(Value& address) { + CEL_INTERNAL_VALUE_MOVE_TO(InlinedStringViewBytesValue, *this, address); +} + +typename InlinedStringViewBytesValue::Rep InlinedStringViewBytesValue::rep() + const { + return Rep(absl::in_place_type, value_); +} + +std::pair StringBytesValue::SizeAndAlignment() const { + return std::make_pair(sizeof(StringBytesValue), alignof(StringBytesValue)); +} + +absl::Cord StringBytesValue::ToCord(bool reference_counted) const { + if (reference_counted) { + Ref(); + return absl::MakeCordFromExternal(absl::string_view(value_), + [this]() { Unref(); }); + } + return absl::Cord(value_); +} + +typename StringBytesValue::Rep StringBytesValue::rep() const { + return Rep(absl::in_place_type, absl::string_view(value_)); +} + +std::pair ExternalDataBytesValue::SizeAndAlignment() const { + return std::make_pair(sizeof(ExternalDataBytesValue), + alignof(ExternalDataBytesValue)); +} + +absl::Cord ExternalDataBytesValue::ToCord(bool reference_counted) const { + if (reference_counted) { + Ref(); + return absl::MakeCordFromExternal( + absl::string_view(static_cast(value_.data), value_.size), + [this]() { Unref(); }); + } + return absl::Cord( + absl::string_view(static_cast(value_.data), value_.size)); +} + +typename ExternalDataBytesValue::Rep ExternalDataBytesValue::rep() const { + return Rep( + absl::in_place_type, + absl::string_view(static_cast(value_.data), value_.size)); +} + +} // namespace base_internal + +} // namespace cel diff --git a/base/values/bytes_value.h b/base/values/bytes_value.h new file mode 100644 index 000000000..f257b003a --- /dev/null +++ b/base/values/bytes_value.h @@ -0,0 +1,205 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_VALUES_BYTES_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_BASE_VALUES_BYTES_VALUE_H_ + +#include +#include +#include + +#include "absl/hash/hash.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "base/kind.h" +#include "base/type.h" +#include "base/value.h" + +namespace cel { + +class ValueFactory; + +class BytesValue : public Value { + protected: + using Rep = base_internal::BytesValueRep; + + public: + static Persistent Empty(ValueFactory& value_factory); + + // Concat concatenates the contents of two ByteValue, returning a new + // ByteValue. The resulting ByteValue is not tied to the lifetime of either of + // the input ByteValue. + static absl::StatusOr> Concat( + ValueFactory& value_factory, const Persistent& lhs, + const Persistent& rhs); + + Persistent type() const final; + + Kind kind() const final { return Kind::kBytes; } + + std::string DebugString() const final; + + size_t size() const; + + bool empty() const; + + bool Equals(absl::string_view bytes) const; + bool Equals(const absl::Cord& bytes) const; + bool Equals(const Persistent& bytes) const; + + int Compare(absl::string_view bytes) const; + int Compare(const absl::Cord& bytes) const; + int Compare(const Persistent& bytes) const; + + std::string ToString() const; + + absl::Cord ToCord() const { + // Without the handle we cannot know if this is reference counted. + return ToCord(/*reference_counted=*/false); + } + + private: + template + friend class base_internal::ValueHandle; + friend class base_internal::ValueHandleBase; + friend class base_internal::InlinedCordBytesValue; + friend class base_internal::InlinedStringViewBytesValue; + friend class base_internal::StringBytesValue; + friend class base_internal::ExternalDataBytesValue; + friend base_internal::BytesValueRep interop_internal::GetBytesValueRep( + const Persistent& value); + + // Called by base_internal::ValueHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Value& value) { return value.kind() == Kind::kBytes; } + + BytesValue() = default; + BytesValue(const BytesValue&) = default; + BytesValue(BytesValue&&) = default; + + // Get the contents of this BytesValue as absl::Cord. When reference_counted + // is true, the implementation can potentially return an absl::Cord that wraps + // the contents instead of copying. + virtual absl::Cord ToCord(bool reference_counted) const = 0; + + // Get the contents of this BytesValue as either absl::string_view or const + // absl::Cord&. + virtual Rep rep() const = 0; + + // See comments for respective member functions on `Value`. + bool Equals(const Value& other) const final; + void HashValue(absl::HashState state) const final; +}; + +namespace base_internal { + +// Implementation of BytesValue that is stored inlined within a handle. Since +// absl::Cord is reference counted itself, this is more efficient than storing +// this on the heap. +class InlinedCordBytesValue final : public BytesValue, public ResourceInlined { + private: + template + friend class ValueHandle; + + explicit InlinedCordBytesValue(absl::Cord value) : value_(std::move(value)) {} + + InlinedCordBytesValue() = delete; + + InlinedCordBytesValue(const InlinedCordBytesValue&) = default; + InlinedCordBytesValue(InlinedCordBytesValue&&) = default; + + // See comments for respective member functions on `ByteValue` and `Value`. + void CopyTo(Value& address) const override; + void MoveTo(Value& address) override; + absl::Cord ToCord(bool reference_counted) const override; + Rep rep() const override; + + absl::Cord value_; +}; + +// Implementation of BytesValue that is stored inlined within a handle. This +// class is inheritently unsafe and care should be taken when using it. +// Typically this should only be used for empty strings or data that is static +// and lives for the duration of a program. +class InlinedStringViewBytesValue final : public BytesValue, + public ResourceInlined { + private: + template + friend class ValueHandle; + + explicit InlinedStringViewBytesValue(absl::string_view value) + : value_(value) {} + + InlinedStringViewBytesValue() = delete; + + InlinedStringViewBytesValue(const InlinedStringViewBytesValue&) = default; + InlinedStringViewBytesValue(InlinedStringViewBytesValue&&) = default; + + // See comments for respective member functions on `ByteValue` and `Value`. + void CopyTo(Value& address) const override; + void MoveTo(Value& address) override; + absl::Cord ToCord(bool reference_counted) const override; + Rep rep() const override; + + absl::string_view value_; +}; + +// Implementation of BytesValue that uses std::string and is allocated on the +// heap, potentially reference counted. +class StringBytesValue final : public BytesValue { + private: + friend class cel::MemoryManager; + + explicit StringBytesValue(std::string value) : value_(std::move(value)) {} + + StringBytesValue() = delete; + StringBytesValue(const StringBytesValue&) = delete; + StringBytesValue(StringBytesValue&&) = delete; + + // See comments for respective member functions on `ByteValue` and `Value`. + std::pair SizeAndAlignment() const override; + absl::Cord ToCord(bool reference_counted) const override; + Rep rep() const override; + + std::string value_; +}; + +// Implementation of BytesValue that wraps a contiguous array of bytes and calls +// the releaser when it is no longer needed. It is stored on the heap and +// potentially reference counted. +class ExternalDataBytesValue final : public BytesValue { + private: + friend class cel::MemoryManager; + + explicit ExternalDataBytesValue(ExternalData value) + : value_(std::move(value)) {} + + ExternalDataBytesValue() = delete; + ExternalDataBytesValue(const ExternalDataBytesValue&) = delete; + ExternalDataBytesValue(ExternalDataBytesValue&&) = delete; + + // See comments for respective member functions on `ByteValue` and `Value`. + std::pair SizeAndAlignment() const override; + absl::Cord ToCord(bool reference_counted) const override; + Rep rep() const override; + + ExternalData value_; +}; + +} // namespace base_internal + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_VALUES_BYTES_VALUE_H_ diff --git a/base/values/double_value.cc b/base/values/double_value.cc new file mode 100644 index 000000000..a0fa86baf --- /dev/null +++ b/base/values/double_value.cc @@ -0,0 +1,83 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/values/double_value.h" + +#include +#include +#include + +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "base/types/double_type.h" +#include "internal/casts.h" + +namespace cel { + +namespace { + +using base_internal::PersistentHandleFactory; + +} + +Persistent DoubleValue::type() const { + return PersistentHandleFactory::MakeUnmanaged( + DoubleType::Get()); +} + +std::string DoubleValue::DebugString() const { + if (std::isfinite(value())) { + if (std::floor(value()) != value()) { + // The double is not representable as a whole number, so use + // absl::StrCat which will add decimal places. + return absl::StrCat(value()); + } + // absl::StrCat historically would represent 0.0 as 0, and we want the + // decimal places so ZetaSQL correctly assumes the type as double + // instead of int64_t. + std::string stringified = absl::StrCat(value()); + if (!absl::StrContains(stringified, '.')) { + absl::StrAppend(&stringified, ".0"); + } else { + // absl::StrCat has a decimal now? Use it directly. + } + return stringified; + } + if (std::isnan(value())) { + return "nan"; + } + if (std::signbit(value())) { + return "-infinity"; + } + return "+infinity"; +} + +void DoubleValue::CopyTo(Value& address) const { + CEL_INTERNAL_VALUE_COPY_TO(DoubleValue, *this, address); +} + +void DoubleValue::MoveTo(Value& address) { + CEL_INTERNAL_VALUE_MOVE_TO(DoubleValue, *this, address); +} + +bool DoubleValue::Equals(const Value& other) const { + return kind() == other.kind() && + value() == internal::down_cast(other).value(); +} + +void DoubleValue::HashValue(absl::HashState state) const { + absl::HashState::combine(std::move(state), type(), value()); +} + +} // namespace cel diff --git a/base/values/double_value.h b/base/values/double_value.h new file mode 100644 index 000000000..c9aa6cb52 --- /dev/null +++ b/base/values/double_value.h @@ -0,0 +1,75 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_VALUES_DOUBLE_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_BASE_VALUES_DOUBLE_VALUE_H_ + +#include + +#include "absl/hash/hash.h" +#include "base/kind.h" +#include "base/type.h" +#include "base/value.h" + +namespace cel { + +class ValueFactory; + +class DoubleValue final : public Value, public base_internal::ResourceInlined { + public: + static Persistent NaN(ValueFactory& value_factory); + + static Persistent PositiveInfinity( + ValueFactory& value_factory); + + static Persistent NegativeInfinity( + ValueFactory& value_factory); + + Persistent type() const override; + + Kind kind() const override { return Kind::kDouble; } + + std::string DebugString() const override; + + constexpr double value() const { return value_; } + + private: + template + friend class base_internal::ValueHandle; + friend class base_internal::ValueHandleBase; + + // Called by base_internal::ValueHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Value& value) { return value.kind() == Kind::kDouble; } + + // Called by `base_internal::ValueHandle` to construct value inline. + explicit DoubleValue(double value) : value_(value) {} + + DoubleValue() = delete; + + DoubleValue(const DoubleValue&) = default; + DoubleValue(DoubleValue&&) = default; + + // See comments for respective member functions on `Value`. + void CopyTo(Value& address) const override; + void MoveTo(Value& address) override; + bool Equals(const Value& other) const override; + void HashValue(absl::HashState state) const override; + + double value_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_VALUES_DOUBLE_VALUE_H_ diff --git a/base/values/duration_value.cc b/base/values/duration_value.cc new file mode 100644 index 000000000..6e648cc29 --- /dev/null +++ b/base/values/duration_value.cc @@ -0,0 +1,58 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/values/duration_value.h" + +#include +#include + +#include "base/types/duration_type.h" +#include "internal/casts.h" +#include "internal/time.h" + +namespace cel { + +namespace { + +using base_internal::PersistentHandleFactory; + +} + +Persistent DurationValue::type() const { + return PersistentHandleFactory::MakeUnmanaged( + DurationType::Get()); +} + +std::string DurationValue::DebugString() const { + return internal::FormatDuration(value()).value(); +} + +void DurationValue::CopyTo(Value& address) const { + CEL_INTERNAL_VALUE_COPY_TO(DurationValue, *this, address); +} + +void DurationValue::MoveTo(Value& address) { + CEL_INTERNAL_VALUE_MOVE_TO(DurationValue, *this, address); +} + +bool DurationValue::Equals(const Value& other) const { + return kind() == other.kind() && + value() == internal::down_cast(other).value(); +} + +void DurationValue::HashValue(absl::HashState state) const { + absl::HashState::combine(std::move(state), type(), value()); +} + +} // namespace cel diff --git a/base/values/duration_value.h b/base/values/duration_value.h new file mode 100644 index 000000000..21b5c8381 --- /dev/null +++ b/base/values/duration_value.h @@ -0,0 +1,71 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_VALUES_DURATION_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_BASE_VALUES_DURATION_VALUE_H_ + +#include + +#include "absl/hash/hash.h" +#include "absl/time/time.h" +#include "base/kind.h" +#include "base/type.h" +#include "base/value.h" + +namespace cel { + +class ValueFactory; + +class DurationValue final : public Value, + public base_internal::ResourceInlined { + public: + static Persistent Zero(ValueFactory& value_factory); + + Persistent type() const override; + + Kind kind() const override { return Kind::kDuration; } + + std::string DebugString() const override; + + constexpr absl::Duration value() const { return value_; } + + private: + template + friend class base_internal::ValueHandle; + friend class base_internal::ValueHandleBase; + + // Called by base_internal::ValueHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Value& value) { return value.kind() == Kind::kDuration; } + + // Called by `base_internal::ValueHandle` to construct value inline. + explicit DurationValue(absl::Duration value) : value_(value) {} + + DurationValue() = delete; + + DurationValue(const DurationValue&) = default; + DurationValue(DurationValue&&) = default; + + // See comments for respective member functions on `Value`. + void CopyTo(Value& address) const override; + void MoveTo(Value& address) override; + bool Equals(const Value& other) const override; + void HashValue(absl::HashState state) const override; + + absl::Duration value_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_VALUES_DURATION_VALUE_H_ diff --git a/base/values/enum_value.cc b/base/values/enum_value.cc new file mode 100644 index 000000000..2457c5fb9 --- /dev/null +++ b/base/values/enum_value.cc @@ -0,0 +1,33 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/values/enum_value.h" + +#include +#include + +#include "internal/casts.h" + +namespace cel { + +bool EnumValue::Equals(const Value& other) const { + return kind() == other.kind() && type() == other.type() && + number() == internal::down_cast(other).number(); +} + +void EnumValue::HashValue(absl::HashState state) const { + absl::HashState::combine(std::move(state), type(), number()); +} + +} // namespace cel diff --git a/base/values/enum_value.h b/base/values/enum_value.h new file mode 100644 index 000000000..186c54811 --- /dev/null +++ b/base/values/enum_value.h @@ -0,0 +1,116 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_VALUES_ENUM_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_BASE_VALUES_ENUM_VALUE_H_ + +#include +#include +#include +#include + +#include "absl/base/macros.h" +#include "absl/hash/hash.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "base/kind.h" +#include "base/type.h" +#include "base/types/enum_type.h" +#include "base/value.h" +#include "internal/rtti.h" + +namespace cel { + +class ValueFactory; + +// EnumValue represents a single constant belonging to cel::EnumType. +class EnumValue : public Value { + public: + static absl::StatusOr> New( + const Persistent& enum_type, ValueFactory& value_factory, + EnumType::ConstantId id); + + Persistent type() const final { return type_; } + + Kind kind() const final { return Kind::kEnum; } + + virtual int64_t number() const = 0; + + virtual absl::string_view name() const = 0; + + protected: + explicit EnumValue(const Persistent& type) : type_(type) { + ABSL_ASSERT(type_); + } + + private: + friend internal::TypeInfo base_internal::GetEnumValueTypeId( + const EnumValue& enum_value); + template + friend class base_internal::ValueHandle; + friend class base_internal::ValueHandleBase; + + // Called by base_internal::ValueHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Value& value) { return value.kind() == Kind::kEnum; } + + EnumValue(const EnumValue&) = delete; + EnumValue(EnumValue&&) = delete; + + bool Equals(const Value& other) const final; + void HashValue(absl::HashState state) const final; + + std::pair SizeAndAlignment() const override = 0; + + // Called by CEL_IMPLEMENT_ENUM_VALUE() and Is() to perform type checking. + virtual internal::TypeInfo TypeId() const = 0; + + Persistent type_; +}; + +// CEL_DECLARE_ENUM_VALUE declares `enum_value` as an enumeration value. It must +// be part of the class definition of `enum_value`. +// +// class MyEnumValue : public cel::EnumValue { +// ... +// private: +// CEL_DECLARE_ENUM_VALUE(MyEnumValue); +// }; +#define CEL_DECLARE_ENUM_VALUE(enum_value) \ + CEL_INTERNAL_DECLARE_VALUE(Enum, enum_value) + +// CEL_IMPLEMENT_ENUM_VALUE implements `enum_value` as an enumeration value. It +// must be called after the class definition of `enum_value`. +// +// class MyEnumValue : public cel::EnumValue { +// ... +// private: +// CEL_DECLARE_ENUM_VALUE(MyEnumValue); +// }; +// +// CEL_IMPLEMENT_ENUM_VALUE(MyEnumValue); +#define CEL_IMPLEMENT_ENUM_VALUE(enum_value) \ + CEL_INTERNAL_IMPLEMENT_VALUE(Enum, enum_value) + +namespace base_internal { + +inline internal::TypeInfo GetEnumValueTypeId(const EnumValue& enum_value) { + return enum_value.TypeId(); +} + +} // namespace base_internal + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_VALUES_ENUM_VALUE_H_ diff --git a/base/values/error_value.cc b/base/values/error_value.cc new file mode 100644 index 000000000..bad4884f6 --- /dev/null +++ b/base/values/error_value.cc @@ -0,0 +1,86 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/values/error_value.h" + +#include +#include + +#include "base/types/error_type.h" +#include "internal/casts.h" + +namespace cel { + +namespace { + +using base_internal::PersistentHandleFactory; + +struct StatusPayload final { + std::string key; + absl::Cord value; +}; + +void StatusHashValue(absl::HashState state, const absl::Status& status) { + // absl::Status::operator== compares `raw_code()`, `message()` and the + // payloads. + state = absl::HashState::combine(std::move(state), status.raw_code(), + status.message()); + // In order to determistically hash, we need to put the payloads in sorted + // order. There is no guarantee from `absl::Status` on the order of the + // payloads returned from `absl::Status::ForEachPayload`. + // + // This should be the same inline size as + // `absl::status_internal::StatusPayloads`. + absl::InlinedVector payloads; + status.ForEachPayload([&](absl::string_view key, const absl::Cord& value) { + payloads.push_back(StatusPayload{std::string(key), value}); + }); + std::stable_sort( + payloads.begin(), payloads.end(), + [](const StatusPayload& lhs, const StatusPayload& rhs) -> bool { + return lhs.key < rhs.key; + }); + for (const auto& payload : payloads) { + state = + absl::HashState::combine(std::move(state), payload.key, payload.value); + } +} + +} // namespace + +Persistent ErrorValue::type() const { + return PersistentHandleFactory::MakeUnmanaged( + ErrorType::Get()); +} + +std::string ErrorValue::DebugString() const { return value().ToString(); } + +void ErrorValue::CopyTo(Value& address) const { + CEL_INTERNAL_VALUE_COPY_TO(ErrorValue, *this, address); +} + +void ErrorValue::MoveTo(Value& address) { + CEL_INTERNAL_VALUE_MOVE_TO(ErrorValue, *this, address); +} + +bool ErrorValue::Equals(const Value& other) const { + return kind() == other.kind() && + value() == internal::down_cast(other).value(); +} + +void ErrorValue::HashValue(absl::HashState state) const { + StatusHashValue(absl::HashState::combine(std::move(state), type()), value()); +} + +} // namespace cel diff --git a/base/values/error_value.h b/base/values/error_value.h new file mode 100644 index 000000000..5b0888c17 --- /dev/null +++ b/base/values/error_value.h @@ -0,0 +1,67 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_VALUES_ERROR_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_BASE_VALUES_ERROR_VALUE_H_ + +#include +#include + +#include "absl/hash/hash.h" +#include "absl/status/status.h" +#include "base/kind.h" +#include "base/type.h" +#include "base/value.h" + +namespace cel { + +class ErrorValue final : public Value, public base_internal::ResourceInlined { + public: + Persistent type() const override; + + Kind kind() const override { return Kind::kError; } + + std::string DebugString() const override; + + const absl::Status& value() const { return value_; } + + private: + template + friend class base_internal::ValueHandle; + friend class base_internal::ValueHandleBase; + + // Called by base_internal::ValueHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Value& value) { return value.kind() == Kind::kError; } + + // Called by `base_internal::ValueHandle` to construct value inline. + explicit ErrorValue(absl::Status value) : value_(std::move(value)) {} + + ErrorValue() = delete; + + ErrorValue(const ErrorValue&) = default; + ErrorValue(ErrorValue&&) = default; + + // See comments for respective member functions on `Value`. + void CopyTo(Value& address) const override; + void MoveTo(Value& address) override; + bool Equals(const Value& other) const override; + void HashValue(absl::HashState state) const override; + + absl::Status value_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_VALUES_ERROR_VALUE_H_ diff --git a/base/values/int_value.cc b/base/values/int_value.cc new file mode 100644 index 000000000..a1588ffec --- /dev/null +++ b/base/values/int_value.cc @@ -0,0 +1,56 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/values/int_value.h" + +#include +#include + +#include "absl/strings/str_cat.h" +#include "base/types/int_type.h" +#include "internal/casts.h" + +namespace cel { + +namespace { + +using base_internal::PersistentHandleFactory; + +} + +Persistent IntValue::type() const { + return PersistentHandleFactory::MakeUnmanaged( + IntType::Get()); +} + +std::string IntValue::DebugString() const { return absl::StrCat(value()); } + +void IntValue::CopyTo(Value& address) const { + CEL_INTERNAL_VALUE_COPY_TO(IntValue, *this, address); +} + +void IntValue::MoveTo(Value& address) { + CEL_INTERNAL_VALUE_MOVE_TO(IntValue, *this, address); +} + +bool IntValue::Equals(const Value& other) const { + return kind() == other.kind() && + value() == internal::down_cast(other).value(); +} + +void IntValue::HashValue(absl::HashState state) const { + absl::HashState::combine(std::move(state), type(), value()); +} + +} // namespace cel diff --git a/base/values/int_value.h b/base/values/int_value.h new file mode 100644 index 000000000..f2fb17a08 --- /dev/null +++ b/base/values/int_value.h @@ -0,0 +1,68 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_VALUES_INT_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_BASE_VALUES_INT_VALUE_H_ + +#include +#include + +#include "absl/hash/hash.h" +#include "base/kind.h" +#include "base/type.h" +#include "base/value.h" + +namespace cel { + +class IntValue; + +class IntValue final : public Value, public base_internal::ResourceInlined { + public: + Persistent type() const override; + + Kind kind() const override { return Kind::kInt; } + + std::string DebugString() const override; + + constexpr int64_t value() const { return value_; } + + private: + template + friend class base_internal::ValueHandle; + friend class base_internal::ValueHandleBase; + + // Called by base_internal::ValueHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Value& value) { return value.kind() == Kind::kInt; } + + // Called by `base_internal::ValueHandle` to construct value inline. + explicit IntValue(int64_t value) : value_(value) {} + + IntValue() = delete; + + IntValue(const IntValue&) = default; + IntValue(IntValue&&) = default; + + // See comments for respective member functions on `Value`. + void CopyTo(Value& address) const override; + void MoveTo(Value& address) override; + bool Equals(const Value& other) const override; + void HashValue(absl::HashState state) const override; + + int64_t value_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_VALUES_INT_VALUE_H_ diff --git a/base/values/list_value.cc b/base/values/list_value.cc new file mode 100644 index 000000000..4af7217e0 --- /dev/null +++ b/base/values/list_value.cc @@ -0,0 +1,21 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/values/list_value.h" + +namespace cel { + +// + +} // namespace cel diff --git a/base/values/list_value.h b/base/values/list_value.h new file mode 100644 index 000000000..abe5b630f --- /dev/null +++ b/base/values/list_value.h @@ -0,0 +1,118 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_VALUES_LIST_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_BASE_VALUES_LIST_VALUE_H_ + +#include +#include +#include +#include + +#include "absl/hash/hash.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "base/kind.h" +#include "base/type.h" +#include "base/types/list_type.h" +#include "base/value.h" +#include "internal/rtti.h" + +namespace cel { + +class ValueFactory; + +// ListValue represents an instance of cel::ListType. +class ListValue : public Value { + public: + // TODO(issues/5): implement iterators so we can have cheap concated lists + + Persistent type() const final { return type_; } + + Kind kind() const final { return Kind::kList; } + + virtual size_t size() const = 0; + + virtual bool empty() const { return size() == 0; } + + virtual absl::StatusOr> Get( + ValueFactory& value_factory, size_t index) const = 0; + + protected: + explicit ListValue(const Persistent& type) : type_(type) {} + + private: + friend internal::TypeInfo base_internal::GetListValueTypeId( + const ListValue& list_value); + template + friend class base_internal::ValueHandle; + friend class base_internal::ValueHandleBase; + + // Called by base_internal::ValueHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Value& value) { return value.kind() == Kind::kList; } + + ListValue(const ListValue&) = delete; + ListValue(ListValue&&) = delete; + + // TODO(issues/5): I do not like this, we should have these two take a + // ValueFactory and return absl::StatusOr and absl::Status. We support + // lazily created values, so errors can occur during equality testing. + // Especially if there are different value implementations for the same type. + bool Equals(const Value& other) const override = 0; + void HashValue(absl::HashState state) const override = 0; + + std::pair SizeAndAlignment() const override = 0; + + // Called by CEL_IMPLEMENT_LIST_VALUE() and Is() to perform type checking. + virtual internal::TypeInfo TypeId() const = 0; + + const Persistent type_; +}; + +// CEL_DECLARE_LIST_VALUE declares `list_value` as an list value. It must +// be part of the class definition of `list_value`. +// +// class MyListValue : public cel::ListValue { +// ... +// private: +// CEL_DECLARE_LIST_VALUE(MyListValue); +// }; +#define CEL_DECLARE_LIST_VALUE(list_value) \ + CEL_INTERNAL_DECLARE_VALUE(List, list_value) + +// CEL_IMPLEMENT_LIST_VALUE implements `list_value` as an list +// value. It must be called after the class definition of `list_value`. +// +// class MyListValue : public cel::ListValue { +// ... +// private: +// CEL_DECLARE_LIST_VALUE(MyListValue); +// }; +// +// CEL_IMPLEMENT_LIST_VALUE(MyListValue); +#define CEL_IMPLEMENT_LIST_VALUE(list_value) \ + CEL_INTERNAL_IMPLEMENT_VALUE(List, list_value) + +namespace base_internal { + +inline internal::TypeInfo GetListValueTypeId(const ListValue& list_value) { + return list_value.TypeId(); +} + +} // namespace base_internal + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_VALUES_LIST_VALUE_H_ diff --git a/base/values/map_value.cc b/base/values/map_value.cc new file mode 100644 index 000000000..37c6405d0 --- /dev/null +++ b/base/values/map_value.cc @@ -0,0 +1,21 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/values/map_value.h" + +namespace cel { + +// + +} // namespace cel diff --git a/base/values/map_value.h b/base/values/map_value.h new file mode 100644 index 000000000..db907feac --- /dev/null +++ b/base/values/map_value.h @@ -0,0 +1,117 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_VALUES_MAP_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_BASE_VALUES_MAP_VALUE_H_ + +#include +#include +#include +#include + +#include "absl/hash/hash.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "base/kind.h" +#include "base/type.h" +#include "base/types/map_type.h" +#include "base/value.h" +#include "internal/rtti.h" + +namespace cel { + +class ValueFactory; + +// MapValue represents an instance of cel::MapType. +class MapValue : public Value { + public: + Persistent type() const final { return type_; } + + Kind kind() const final { return Kind::kMap; } + + virtual size_t size() const = 0; + + virtual bool empty() const { return size() == 0; } + + virtual absl::StatusOr> Get( + ValueFactory& value_factory, + const Persistent& key) const = 0; + + virtual absl::StatusOr Has( + const Persistent& key) const = 0; + + protected: + explicit MapValue(const Persistent& type) : type_(type) {} + + private: + friend internal::TypeInfo base_internal::GetMapValueTypeId( + const MapValue& map_value); + template + friend class base_internal::ValueHandle; + friend class base_internal::ValueHandleBase; + + // Called by base_internal::ValueHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Value& value) { return value.kind() == Kind::kMap; } + + MapValue(const MapValue&) = delete; + MapValue(MapValue&&) = delete; + + bool Equals(const Value& other) const override = 0; + void HashValue(absl::HashState state) const override = 0; + + std::pair SizeAndAlignment() const override = 0; + + // Called by CEL_IMPLEMENT_MAP_VALUE() and Is() to perform type checking. + virtual internal::TypeInfo TypeId() const = 0; + + // Set lazily, by EnumValue::New. + Persistent type_; +}; + +// CEL_DECLARE_MAP_VALUE declares `map_value` as an map value. It must +// be part of the class definition of `map_value`. +// +// class MyMapValue : public cel::MapValue { +// ... +// private: +// CEL_DECLARE_MAP_VALUE(MyMapValue); +// }; +#define CEL_DECLARE_MAP_VALUE(map_value) \ + CEL_INTERNAL_DECLARE_VALUE(Map, map_value) + +// CEL_IMPLEMENT_MAP_VALUE implements `map_value` as an map +// value. It must be called after the class definition of `map_value`. +// +// class MyMapValue : public cel::MapValue { +// ... +// private: +// CEL_DECLARE_MAP_VALUE(MyMapValue); +// }; +// +// CEL_IMPLEMENT_MAP_VALUE(MyMapValue); +#define CEL_IMPLEMENT_MAP_VALUE(map_value) \ + CEL_INTERNAL_IMPLEMENT_VALUE(Map, map_value) + +namespace base_internal { + +inline internal::TypeInfo GetMapValueTypeId(const MapValue& map_value) { + return map_value.TypeId(); +} + +} // namespace base_internal + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_VALUES_MAP_VALUE_H_ diff --git a/base/values/null_value.cc b/base/values/null_value.cc new file mode 100644 index 000000000..87f90d368 --- /dev/null +++ b/base/values/null_value.cc @@ -0,0 +1,59 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/values/null_value.h" + +#include +#include + +#include "base/types/null_type.h" +#include "internal/no_destructor.h" + +namespace cel { + +namespace { + +using base_internal::PersistentHandleFactory; + +} + +Persistent NullValue::type() const { + return PersistentHandleFactory::MakeUnmanaged( + NullType::Get()); +} + +std::string NullValue::DebugString() const { return "null"; } + +const NullValue& NullValue::Get() { + static const internal::NoDestructor instance; + return *instance; +} + +void NullValue::CopyTo(Value& address) const { + CEL_INTERNAL_VALUE_COPY_TO(NullValue, *this, address); +} + +void NullValue::MoveTo(Value& address) { + CEL_INTERNAL_VALUE_MOVE_TO(NullValue, *this, address); +} + +bool NullValue::Equals(const Value& other) const { + return kind() == other.kind(); +} + +void NullValue::HashValue(absl::HashState state) const { + absl::HashState::combine(std::move(state), type(), 0); +} + +} // namespace cel diff --git a/base/values/null_value.h b/base/values/null_value.h new file mode 100644 index 000000000..53098c059 --- /dev/null +++ b/base/values/null_value.h @@ -0,0 +1,69 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_VALUES_NULL_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_BASE_VALUES_NULL_VALUE_H_ + +#include + +#include "absl/hash/hash.h" +#include "base/kind.h" +#include "base/type.h" +#include "base/value.h" + +namespace cel { + +class ValueFactory; + +class NullValue final : public Value, public base_internal::ResourceInlined { + public: + static Persistent Get(ValueFactory& value_factory); + + Persistent type() const override; + + Kind kind() const override { return Kind::kNullType; } + + std::string DebugString() const override; + + // Note GCC does not consider a friend member as a member of a friend. + ABSL_ATTRIBUTE_PURE_FUNCTION static const NullValue& Get(); + + bool Equals(const Value& other) const override; + + void HashValue(absl::HashState state) const override; + + private: + friend class ValueFactory; + template + friend class internal::NoDestructor; + template + friend class base_internal::ValueHandle; + friend class base_internal::ValueHandleBase; + + // Called by base_internal::ValueHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Value& value) { return value.kind() == Kind::kNullType; } + + NullValue() = default; + NullValue(const NullValue&) = default; + NullValue(NullValue&&) = default; + + // See comments for respective member functions on `Value`. + void CopyTo(Value& address) const override; + void MoveTo(Value& address) override; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_VALUES_NULL_VALUE_H_ diff --git a/base/values/string_value.cc b/base/values/string_value.cc new file mode 100644 index 000000000..77fb5c437 --- /dev/null +++ b/base/values/string_value.cc @@ -0,0 +1,320 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/values/string_value.h" + +#include +#include + +#include "base/types/string_type.h" +#include "internal/casts.h" +#include "internal/strings.h" +#include "internal/utf8.h" + +namespace cel { + +namespace { + +using base_internal::PersistentHandleFactory; + +struct StringValueDebugStringVisitor final { + std::string operator()(absl::string_view value) const { + return internal::FormatStringLiteral(value); + } + + std::string operator()(const absl::Cord& value) const { + return internal::FormatStringLiteral(static_cast(value)); + } +}; + +struct ToStringVisitor final { + std::string operator()(absl::string_view value) const { + return std::string(value); + } + + std::string operator()(const absl::Cord& value) const { + return static_cast(value); + } +}; + +struct StringValueSizeVisitor final { + size_t operator()(absl::string_view value) const { + return internal::Utf8CodePointCount(value); + } + + size_t operator()(const absl::Cord& value) const { + return internal::Utf8CodePointCount(value); + } +}; + +struct EmptyVisitor final { + bool operator()(absl::string_view value) const { return value.empty(); } + + bool operator()(const absl::Cord& value) const { return value.empty(); } +}; + +bool EqualsImpl(absl::string_view lhs, absl::string_view rhs) { + return lhs == rhs; +} + +bool EqualsImpl(absl::string_view lhs, const absl::Cord& rhs) { + return lhs == rhs; +} + +bool EqualsImpl(const absl::Cord& lhs, absl::string_view rhs) { + return lhs == rhs; +} + +bool EqualsImpl(const absl::Cord& lhs, const absl::Cord& rhs) { + return lhs == rhs; +} + +int CompareImpl(absl::string_view lhs, absl::string_view rhs) { + return lhs.compare(rhs); +} + +int CompareImpl(absl::string_view lhs, const absl::Cord& rhs) { + return -rhs.Compare(lhs); +} + +int CompareImpl(const absl::Cord& lhs, absl::string_view rhs) { + return lhs.Compare(rhs); +} + +int CompareImpl(const absl::Cord& lhs, const absl::Cord& rhs) { + return lhs.Compare(rhs); +} + +template +class EqualsVisitor final { + public: + explicit EqualsVisitor(const T& ref) : ref_(ref) {} + + bool operator()(absl::string_view value) const { + return EqualsImpl(value, ref_); + } + + bool operator()(const absl::Cord& value) const { + return EqualsImpl(value, ref_); + } + + private: + const T& ref_; +}; + +template <> +class EqualsVisitor final { + public: + explicit EqualsVisitor(const StringValue& ref) : ref_(ref) {} + + bool operator()(absl::string_view value) const { return ref_.Equals(value); } + + bool operator()(const absl::Cord& value) const { return ref_.Equals(value); } + + private: + const StringValue& ref_; +}; + +template +class CompareVisitor final { + public: + explicit CompareVisitor(const T& ref) : ref_(ref) {} + + int operator()(absl::string_view value) const { + return CompareImpl(value, ref_); + } + + int operator()(const absl::Cord& value) const { + return CompareImpl(value, ref_); + } + + private: + const T& ref_; +}; + +template <> +class CompareVisitor final { + public: + explicit CompareVisitor(const StringValue& ref) : ref_(ref) {} + + int operator()(const absl::Cord& value) const { return ref_.Compare(value); } + + int operator()(absl::string_view value) const { return ref_.Compare(value); } + + private: + const StringValue& ref_; +}; + +class HashValueVisitor final { + public: + explicit HashValueVisitor(absl::HashState state) : state_(std::move(state)) {} + + void operator()(absl::string_view value) { + absl::HashState::combine(std::move(state_), value); + } + + void operator()(const absl::Cord& value) { + absl::HashState::combine(std::move(state_), value); + } + + private: + absl::HashState state_; +}; + +} // namespace + +Persistent StringValue::type() const { + return PersistentHandleFactory::MakeUnmanaged( + StringType::Get()); +} + +size_t StringValue::size() const { + // We lazily calculate the code point count in some circumstances. If the code + // point count is 0 and the underlying rep is not empty we need to actually + // calculate the size. It is okay if this is done by multiple threads + // simultaneously, it is a benign race. + size_t size = size_.load(std::memory_order_relaxed); + if (size == 0 && !empty()) { + size = absl::visit(StringValueSizeVisitor{}, rep()); + size_.store(size, std::memory_order_relaxed); + } + return size; +} + +bool StringValue::empty() const { return absl::visit(EmptyVisitor{}, rep()); } + +bool StringValue::Equals(absl::string_view string) const { + return absl::visit(EqualsVisitor(string), rep()); +} + +bool StringValue::Equals(const absl::Cord& string) const { + return absl::visit(EqualsVisitor(string), rep()); +} + +bool StringValue::Equals(const Persistent& string) const { + return absl::visit(EqualsVisitor(*this), string->rep()); +} + +int StringValue::Compare(absl::string_view string) const { + return absl::visit(CompareVisitor(string), rep()); +} + +int StringValue::Compare(const absl::Cord& string) const { + return absl::visit(CompareVisitor(string), rep()); +} + +int StringValue::Compare(const Persistent& string) const { + return absl::visit(CompareVisitor(*this), string->rep()); +} + +std::string StringValue::ToString() const { + return absl::visit(ToStringVisitor{}, rep()); +} + +std::string StringValue::DebugString() const { + return absl::visit(StringValueDebugStringVisitor{}, rep()); +} + +bool StringValue::Equals(const Value& other) const { + return kind() == other.kind() && + absl::visit(EqualsVisitor(*this), + internal::down_cast(other).rep()); +} + +void StringValue::HashValue(absl::HashState state) const { + absl::visit( + HashValueVisitor(absl::HashState::combine(std::move(state), type())), + rep()); +} + +namespace base_internal { + +absl::Cord InlinedCordStringValue::ToCord(bool reference_counted) const { + static_cast(reference_counted); + return value_; +} + +void InlinedCordStringValue::CopyTo(Value& address) const { + CEL_INTERNAL_VALUE_COPY_TO(InlinedCordStringValue, *this, address); +} + +void InlinedCordStringValue::MoveTo(Value& address) { + CEL_INTERNAL_VALUE_MOVE_TO(InlinedCordStringValue, *this, address); +} + +typename InlinedCordStringValue::Rep InlinedCordStringValue::rep() const { + return Rep(absl::in_place_type>, + std::cref(value_)); +} + +absl::Cord InlinedStringViewStringValue::ToCord(bool reference_counted) const { + static_cast(reference_counted); + return absl::Cord(value_); +} + +void InlinedStringViewStringValue::CopyTo(Value& address) const { + CEL_INTERNAL_VALUE_COPY_TO(InlinedStringViewStringValue, *this, address); +} + +void InlinedStringViewStringValue::MoveTo(Value& address) { + CEL_INTERNAL_VALUE_MOVE_TO(InlinedStringViewStringValue, *this, address); +} + +typename InlinedStringViewStringValue::Rep InlinedStringViewStringValue::rep() + const { + return Rep(absl::in_place_type, value_); +} + +std::pair StringStringValue::SizeAndAlignment() const { + return std::make_pair(sizeof(StringStringValue), alignof(StringStringValue)); +} + +absl::Cord StringStringValue::ToCord(bool reference_counted) const { + if (reference_counted) { + Ref(); + return absl::MakeCordFromExternal(absl::string_view(value_), + [this]() { Unref(); }); + } + return absl::Cord(value_); +} + +typename StringStringValue::Rep StringStringValue::rep() const { + return Rep(absl::in_place_type, absl::string_view(value_)); +} + +std::pair ExternalDataStringValue::SizeAndAlignment() const { + return std::make_pair(sizeof(ExternalDataStringValue), + alignof(ExternalDataStringValue)); +} + +absl::Cord ExternalDataStringValue::ToCord(bool reference_counted) const { + if (reference_counted) { + Ref(); + return absl::MakeCordFromExternal( + absl::string_view(static_cast(value_.data), value_.size), + [this]() { Unref(); }); + } + return absl::Cord( + absl::string_view(static_cast(value_.data), value_.size)); +} + +typename ExternalDataStringValue::Rep ExternalDataStringValue::rep() const { + return Rep( + absl::in_place_type, + absl::string_view(static_cast(value_.data), value_.size)); +} + +} // namespace base_internal + +} // namespace cel diff --git a/base/values/string_value.h b/base/values/string_value.h new file mode 100644 index 000000000..edd07381b --- /dev/null +++ b/base/values/string_value.h @@ -0,0 +1,226 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_VALUES_STRING_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_BASE_VALUES_STRING_VALUE_H_ + +#include +#include +#include + +#include "absl/hash/hash.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "base/kind.h" +#include "base/type.h" +#include "base/value.h" + +namespace cel { + +class ValueFactory; + +class StringValue : public Value { + protected: + using Rep = base_internal::StringValueRep; + + public: + static Persistent Empty(ValueFactory& value_factory); + + static absl::StatusOr> Concat( + ValueFactory& value_factory, const Persistent& lhs, + const Persistent& rhs); + + Persistent type() const final; + + Kind kind() const final { return Kind::kString; } + + std::string DebugString() const final; + + size_t size() const; + + bool empty() const; + + bool Equals(absl::string_view string) const; + bool Equals(const absl::Cord& string) const; + bool Equals(const Persistent& string) const; + + int Compare(absl::string_view string) const; + int Compare(const absl::Cord& string) const; + int Compare(const Persistent& string) const; + + std::string ToString() const; + + absl::Cord ToCord() const { + // Without the handle we cannot know if this is reference counted. + return ToCord(/*reference_counted=*/false); + } + + private: + template + friend class base_internal::ValueHandle; + friend class base_internal::ValueHandleBase; + friend class base_internal::InlinedCordStringValue; + friend class base_internal::InlinedStringViewStringValue; + friend class base_internal::StringStringValue; + friend class base_internal::ExternalDataStringValue; + friend base_internal::StringValueRep interop_internal::GetStringValueRep( + const Persistent& value); + + // Called by base_internal::ValueHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Value& value) { return value.kind() == Kind::kString; } + + explicit StringValue(size_t size) : size_(size) {} + + StringValue() = default; + + StringValue(const StringValue& other) + : StringValue(other.size_.load(std::memory_order_relaxed)) {} + + StringValue(StringValue&& other) + : StringValue(other.size_.exchange(0, std::memory_order_relaxed)) {} + + // Get the contents of this BytesValue as absl::Cord. When reference_counted + // is true, the implementation can potentially return an absl::Cord that wraps + // the contents instead of copying. + virtual absl::Cord ToCord(bool reference_counted) const = 0; + + // Get the contents of this StringValue as either absl::string_view or const + // absl::Cord&. + virtual Rep rep() const = 0; + + // See comments for respective member functions on `Value`. + bool Equals(const Value& other) const final; + void HashValue(absl::HashState state) const final; + + // Lazily cached code point count. + mutable std::atomic size_ = 0; +}; + +namespace base_internal { + +// Implementation of StringValue that is stored inlined within a handle. Since +// absl::Cord is reference counted itself, this is more efficient then storing +// this on the heap. +class InlinedCordStringValue final : public StringValue, + public ResourceInlined { + private: + template + friend class ValueHandle; + + explicit InlinedCordStringValue(absl::Cord value) + : InlinedCordStringValue(0, std::move(value)) {} + + InlinedCordStringValue(size_t size, absl::Cord value) + : StringValue(size), value_(std::move(value)) {} + + InlinedCordStringValue() = delete; + + InlinedCordStringValue(const InlinedCordStringValue&) = default; + InlinedCordStringValue(InlinedCordStringValue&&) = default; + + // See comments for respective member functions on `StringValue` and `Value`. + void CopyTo(Value& address) const override; + void MoveTo(Value& address) override; + absl::Cord ToCord(bool reference_counted) const override; + Rep rep() const override; + + absl::Cord value_; +}; + +// Implementation of StringValue that is stored inlined within a handle. This +// class is inheritently unsafe and care should be taken when using it. +// Typically this should only be used for empty strings or data that is static +// and lives for the duration of a program. +class InlinedStringViewStringValue final : public StringValue, + public ResourceInlined { + private: + template + friend class ValueHandle; + + explicit InlinedStringViewStringValue(absl::string_view value) + : InlinedStringViewStringValue(0, value) {} + + InlinedStringViewStringValue(size_t size, absl::string_view value) + : StringValue(size), value_(value) {} + + InlinedStringViewStringValue() = delete; + + InlinedStringViewStringValue(const InlinedStringViewStringValue&) = default; + InlinedStringViewStringValue(InlinedStringViewStringValue&&) = default; + + // See comments for respective member functions on `StringValue` and `Value`. + void CopyTo(Value& address) const override; + void MoveTo(Value& address) override; + absl::Cord ToCord(bool reference_counted) const override; + Rep rep() const override; + + absl::string_view value_; +}; + +// Implementation of StringValue that uses std::string and is allocated on the +// heap, potentially reference counted. +class StringStringValue final : public StringValue { + private: + friend class cel::MemoryManager; + + explicit StringStringValue(std::string value) + : StringStringValue(0, std::move(value)) {} + + StringStringValue(size_t size, std::string value) + : StringValue(size), value_(std::move(value)) {} + + StringStringValue() = delete; + StringStringValue(const StringStringValue&) = delete; + StringStringValue(StringStringValue&&) = delete; + + // See comments for respective member functions on `StringValue` and `Value`. + std::pair SizeAndAlignment() const override; + absl::Cord ToCord(bool reference_counted) const override; + Rep rep() const override; + + std::string value_; +}; + +// Implementation of StringValue that wraps a contiguous array of bytes and +// calls the releaser when it is no longer needed. It is stored on the heap and +// potentially reference counted. +class ExternalDataStringValue final : public StringValue { + private: + friend class cel::MemoryManager; + + explicit ExternalDataStringValue(ExternalData value) + : ExternalDataStringValue(0, std::move(value)) {} + + ExternalDataStringValue(size_t size, ExternalData value) + : StringValue(size), value_(std::move(value)) {} + + ExternalDataStringValue() = delete; + ExternalDataStringValue(const ExternalDataStringValue&) = delete; + ExternalDataStringValue(ExternalDataStringValue&&) = delete; + + // See comments for respective member functions on `StringValue` and `Value`. + std::pair SizeAndAlignment() const override; + absl::Cord ToCord(bool reference_counted) const override; + Rep rep() const override; + + ExternalData value_; +}; + +} // namespace base_internal + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_VALUES_STRING_VALUE_H_ diff --git a/base/values/struct_value.cc b/base/values/struct_value.cc new file mode 100644 index 000000000..7ca2e21b5 --- /dev/null +++ b/base/values/struct_value.cc @@ -0,0 +1,77 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/values/struct_value.h" + +#include +#include + +#include "base/types/struct_type.h" + +namespace cel { + +struct StructValue::SetFieldVisitor final { + StructValue& struct_value; + const Persistent& value; + + absl::Status operator()(absl::string_view name) const { + return struct_value.SetFieldByName(name, value); + } + + absl::Status operator()(int64_t number) const { + return struct_value.SetFieldByNumber(number, value); + } +}; + +struct StructValue::GetFieldVisitor final { + const StructValue& struct_value; + ValueFactory& value_factory; + + absl::StatusOr> operator()( + absl::string_view name) const { + return struct_value.GetFieldByName(value_factory, name); + } + + absl::StatusOr> operator()(int64_t number) const { + return struct_value.GetFieldByNumber(value_factory, number); + } +}; + +struct StructValue::HasFieldVisitor final { + const StructValue& struct_value; + + absl::StatusOr operator()(absl::string_view name) const { + return struct_value.HasFieldByName(name); + } + + absl::StatusOr operator()(int64_t number) const { + return struct_value.HasFieldByNumber(number); + } +}; + +absl::Status StructValue::SetField(FieldId field, + const Persistent& value) { + return absl::visit(SetFieldVisitor{*this, value}, field.data_); +} + +absl::StatusOr> StructValue::GetField( + ValueFactory& value_factory, FieldId field) const { + return absl::visit(GetFieldVisitor{*this, value_factory}, field.data_); +} + +absl::StatusOr StructValue::HasField(FieldId field) const { + return absl::visit(HasFieldVisitor{*this}, field.data_); +} + +} // namespace cel diff --git a/base/values/struct_value.h b/base/values/struct_value.h new file mode 100644 index 000000000..8afc3c04e --- /dev/null +++ b/base/values/struct_value.h @@ -0,0 +1,145 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_VALUES_STRUCT_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_BASE_VALUES_STRUCT_VALUE_H_ + +#include +#include +#include +#include + +#include "absl/base/macros.h" +#include "absl/hash/hash.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "base/kind.h" +#include "base/type.h" +#include "base/types/struct_type.h" +#include "base/value.h" +#include "internal/rtti.h" + +namespace cel { + +class ValueFactory; + +// StructValue represents an instance of cel::StructType. +class StructValue : public Value { + public: + using FieldId = StructType::FieldId; + + static absl::StatusOr> New( + const Persistent& struct_type, + ValueFactory& value_factory); + + Persistent type() const final { return type_; } + + Kind kind() const final { return Kind::kStruct; } + + absl::Status SetField(FieldId field, const Persistent& value); + + absl::StatusOr> GetField(ValueFactory& value_factory, + FieldId field) const; + + absl::StatusOr HasField(FieldId field) const; + + protected: + explicit StructValue(const Persistent& type) : type_(type) { + ABSL_ASSERT(type_); + } + + virtual absl::Status SetFieldByName(absl::string_view name, + const Persistent& value) = 0; + + virtual absl::Status SetFieldByNumber( + int64_t number, const Persistent& value) = 0; + + virtual absl::StatusOr> GetFieldByName( + ValueFactory& value_factory, absl::string_view name) const = 0; + + virtual absl::StatusOr> GetFieldByNumber( + ValueFactory& value_factory, int64_t number) const = 0; + + virtual absl::StatusOr HasFieldByName(absl::string_view name) const = 0; + + virtual absl::StatusOr HasFieldByNumber(int64_t number) const = 0; + + private: + struct SetFieldVisitor; + struct GetFieldVisitor; + struct HasFieldVisitor; + + friend struct SetFieldVisitor; + friend struct GetFieldVisitor; + friend struct HasFieldVisitor; + friend internal::TypeInfo base_internal::GetStructValueTypeId( + const StructValue& struct_value); + template + friend class base_internal::ValueHandle; + friend class base_internal::ValueHandleBase; + + // Called by base_internal::ValueHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Value& value) { return value.kind() == Kind::kStruct; } + + StructValue(const StructValue&) = delete; + StructValue(StructValue&&) = delete; + + bool Equals(const Value& other) const override = 0; + void HashValue(absl::HashState state) const override = 0; + + std::pair SizeAndAlignment() const override = 0; + + // Called by CEL_IMPLEMENT_STRUCT_VALUE() and Is() to perform type checking. + virtual internal::TypeInfo TypeId() const = 0; + + Persistent type_; +}; + +// CEL_DECLARE_STRUCT_VALUE declares `struct_value` as an struct value. It must +// be part of the class definition of `struct_value`. +// +// class MyStructValue : public cel::StructValue { +// ... +// private: +// CEL_DECLARE_STRUCT_VALUE(MyStructValue); +// }; +#define CEL_DECLARE_STRUCT_VALUE(struct_value) \ + CEL_INTERNAL_DECLARE_VALUE(Struct, struct_value) + +// CEL_IMPLEMENT_STRUCT_VALUE implements `struct_value` as an struct +// value. It must be called after the class definition of `struct_value`. +// +// class MyStructValue : public cel::StructValue { +// ... +// private: +// CEL_DECLARE_STRUCT_VALUE(MyStructValue); +// }; +// +// CEL_IMPLEMENT_STRUCT_VALUE(MyStructValue); +#define CEL_IMPLEMENT_STRUCT_VALUE(struct_value) \ + CEL_INTERNAL_IMPLEMENT_VALUE(Struct, struct_value) + +namespace base_internal { + +inline internal::TypeInfo GetStructValueTypeId( + const StructValue& struct_value) { + return struct_value.TypeId(); +} + +} // namespace base_internal + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_VALUES_STRUCT_VALUE_H_ diff --git a/base/values/timestamp_value.cc b/base/values/timestamp_value.cc new file mode 100644 index 000000000..2cc2079c5 --- /dev/null +++ b/base/values/timestamp_value.cc @@ -0,0 +1,58 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/values/timestamp_value.h" + +#include +#include + +#include "base/types/timestamp_type.h" +#include "internal/casts.h" +#include "internal/time.h" + +namespace cel { + +namespace { + +using base_internal::PersistentHandleFactory; + +} + +Persistent TimestampValue::type() const { + return PersistentHandleFactory::MakeUnmanaged< + const TimestampType>(TimestampType::Get()); +} + +std::string TimestampValue::DebugString() const { + return internal::FormatTimestamp(value()).value(); +} + +void TimestampValue::CopyTo(Value& address) const { + CEL_INTERNAL_VALUE_COPY_TO(TimestampValue, *this, address); +} + +void TimestampValue::MoveTo(Value& address) { + CEL_INTERNAL_VALUE_MOVE_TO(TimestampValue, *this, address); +} + +bool TimestampValue::Equals(const Value& other) const { + return kind() == other.kind() && + value() == internal::down_cast(other).value(); +} + +void TimestampValue::HashValue(absl::HashState state) const { + absl::HashState::combine(std::move(state), type(), value()); +} + +} // namespace cel diff --git a/base/values/timestamp_value.h b/base/values/timestamp_value.h new file mode 100644 index 000000000..9e22c1a7c --- /dev/null +++ b/base/values/timestamp_value.h @@ -0,0 +1,74 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_VALUES_TIMESTAMP_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_BASE_VALUES_TIMESTAMP_VALUE_H_ + +#include + +#include "absl/hash/hash.h" +#include "absl/time/time.h" +#include "base/kind.h" +#include "base/type.h" +#include "base/value.h" + +namespace cel { + +class ValueFactory; + +class TimestampValue final : public Value, + public base_internal::ResourceInlined { + public: + static Persistent UnixEpoch( + ValueFactory& value_factory); + + Persistent type() const override; + + Kind kind() const override { return Kind::kTimestamp; } + + std::string DebugString() const override; + + constexpr absl::Time value() const { return value_; } + + private: + template + friend class base_internal::ValueHandle; + friend class base_internal::ValueHandleBase; + + // Called by base_internal::ValueHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Value& value) { + return value.kind() == Kind::kTimestamp; + } + + // Called by `base_internal::ValueHandle` to construct value inline. + explicit TimestampValue(absl::Time value) : value_(value) {} + + TimestampValue() = delete; + + TimestampValue(const TimestampValue&) = default; + TimestampValue(TimestampValue&&) = default; + + // See comments for respective member functions on `Value`. + void CopyTo(Value& address) const override; + void MoveTo(Value& address) override; + bool Equals(const Value& other) const override; + void HashValue(absl::HashState state) const override; + + absl::Time value_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_VALUES_TIMESTAMP_VALUE_H_ diff --git a/base/values/type_value.cc b/base/values/type_value.cc new file mode 100644 index 000000000..748738c7a --- /dev/null +++ b/base/values/type_value.cc @@ -0,0 +1,55 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/values/type_value.h" + +#include +#include + +#include "base/types/type_type.h" +#include "internal/casts.h" + +namespace cel { + +namespace { + +using base_internal::PersistentHandleFactory; + +} + +Persistent TypeValue::type() const { + return PersistentHandleFactory::MakeUnmanaged( + TypeType::Get()); +} + +std::string TypeValue::DebugString() const { return value()->DebugString(); } + +void TypeValue::CopyTo(Value& address) const { + CEL_INTERNAL_VALUE_COPY_TO(TypeValue, *this, address); +} + +void TypeValue::MoveTo(Value& address) { + CEL_INTERNAL_VALUE_MOVE_TO(TypeValue, *this, address); +} + +bool TypeValue::Equals(const Value& other) const { + return kind() == other.kind() && + value() == internal::down_cast(other).value(); +} + +void TypeValue::HashValue(absl::HashState state) const { + absl::HashState::combine(std::move(state), type(), value()); +} + +} // namespace cel diff --git a/base/values/type_value.h b/base/values/type_value.h new file mode 100644 index 000000000..0635a7988 --- /dev/null +++ b/base/values/type_value.h @@ -0,0 +1,67 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_VALUES_TYPE_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_BASE_VALUES_TYPE_VALUE_H_ + +#include +#include + +#include "absl/hash/hash.h" +#include "base/kind.h" +#include "base/type.h" +#include "base/value.h" + +namespace cel { + +// TypeValue represents an instance of cel::Type. +class TypeValue final : public Value, base_internal::ResourceInlined { + public: + Persistent type() const override; + + Kind kind() const override { return Kind::kType; } + + std::string DebugString() const override; + + Persistent value() const { return value_; } + + private: + template + friend class base_internal::ValueHandle; + friend class base_internal::ValueHandleBase; + + // Called by base_internal::ValueHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Value& value) { return value.kind() == Kind::kType; } + + // Called by `base_internal::ValueHandle` to construct value inline. + explicit TypeValue(Persistent type) : value_(std::move(type)) {} + + TypeValue() = delete; + + TypeValue(const TypeValue&) = default; + TypeValue(TypeValue&&) = default; + + // See comments for respective member functions on `Value`. + void CopyTo(Value& address) const override; + void MoveTo(Value& address) override; + bool Equals(const Value& other) const override; + void HashValue(absl::HashState state) const override; + + Persistent value_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_VALUES_TYPE_VALUE_H_ diff --git a/base/values/uint_value.cc b/base/values/uint_value.cc new file mode 100644 index 000000000..a8a0d143f --- /dev/null +++ b/base/values/uint_value.cc @@ -0,0 +1,58 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/values/uint_value.h" + +#include +#include + +#include "absl/strings/str_cat.h" +#include "base/types/uint_type.h" +#include "internal/casts.h" + +namespace cel { + +namespace { + +using base_internal::PersistentHandleFactory; + +} + +Persistent UintValue::type() const { + return PersistentHandleFactory::MakeUnmanaged( + UintType::Get()); +} + +std::string UintValue::DebugString() const { + return absl::StrCat(value(), "u"); +} + +void UintValue::CopyTo(Value& address) const { + CEL_INTERNAL_VALUE_COPY_TO(UintValue, *this, address); +} + +void UintValue::MoveTo(Value& address) { + CEL_INTERNAL_VALUE_MOVE_TO(UintValue, *this, address); +} + +bool UintValue::Equals(const Value& other) const { + return kind() == other.kind() && + value() == internal::down_cast(other).value(); +} + +void UintValue::HashValue(absl::HashState state) const { + absl::HashState::combine(std::move(state), type(), value()); +} + +} // namespace cel diff --git a/base/values/uint_value.h b/base/values/uint_value.h new file mode 100644 index 000000000..d76c44d06 --- /dev/null +++ b/base/values/uint_value.h @@ -0,0 +1,66 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_VALUES_UINT_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_BASE_VALUES_UINT_VALUE_H_ + +#include +#include + +#include "absl/hash/hash.h" +#include "base/kind.h" +#include "base/type.h" +#include "base/value.h" + +namespace cel { + +class UintValue final : public Value, public base_internal::ResourceInlined { + public: + Persistent type() const override; + + Kind kind() const override { return Kind::kUint; } + + std::string DebugString() const override; + + constexpr uint64_t value() const { return value_; } + + private: + template + friend class base_internal::ValueHandle; + friend class base_internal::ValueHandleBase; + + // Called by base_internal::ValueHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Value& value) { return value.kind() == Kind::kUint; } + + // Called by `base_internal::ValueHandle` to construct value inline. + explicit UintValue(uint64_t value) : value_(value) {} + + UintValue() = delete; + + UintValue(const UintValue&) = default; + UintValue(UintValue&&) = default; + + // See comments for respective member functions on `Value`. + void CopyTo(Value& address) const override; + void MoveTo(Value& address) override; + bool Equals(const Value& other) const override; + void HashValue(absl::HashState state) const override; + + uint64_t value_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_VALUES_UINT_VALUE_H_ From 3fdd3d037c311176771a118afe7d64c249b32260 Mon Sep 17 00:00:00 2001 From: jcking Date: Thu, 9 Jun 2022 15:14:03 +0000 Subject: [PATCH 003/303] Internal change PiperOrigin-RevId: 453928706 --- base/BUILD | 58 ++++------- base/types/BUILD | 241 ------------------------------------------ base/values/BUILD | 259 ---------------------------------------------- 3 files changed, 20 insertions(+), 538 deletions(-) delete mode 100644 base/types/BUILD delete mode 100644 base/values/BUILD diff --git a/base/BUILD b/base/BUILD index 3a1686404..85d4bdda6 100644 --- a/base/BUILD +++ b/base/BUILD @@ -103,16 +103,19 @@ cc_library( name = "type", srcs = [ "type.cc", - ], + ] + glob(["types/*.cc"]), hdrs = [ "type.h", - ], + ] + glob(["types/*.h"]), deps = [ ":handle", ":kind", ":memory_manager", "//base/internal:type", + "//internal:no_destructor", + "//internal:rtti", "@com_google_absl//absl/hash", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:variant", ], @@ -165,23 +168,7 @@ cc_library( deps = [ ":handle", ":memory_manager", - "//base/types:any", - "//base/types:bool", - "//base/types:bytes", - "//base/types:double", - "//base/types:duration", - "//base/types:dyn", - "//base/types:enum", - "//base/types:error", - "//base/types:int", - "//base/types:list", - "//base/types:map", - "//base/types:null", - "//base/types:string", - "//base/types:struct", - "//base/types:timestamp", - "//base/types:type", - "//base/types:uint", + ":type", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", @@ -203,8 +190,6 @@ cc_test( ":type_manager", ":value", "//base/internal:memory_manager_testing", - "//base/values:enum", - "//base/values:struct", "//internal:testing", "@com_google_absl//absl/hash", "@com_google_absl//absl/hash:hash_testing", @@ -216,17 +201,29 @@ cc_library( name = "value", srcs = [ "value.cc", - ], + ] + glob(["values/*.cc"]), hdrs = [ "value.h", - ], + ] + glob(["values/*.h"]), deps = [ ":handle", ":kind", ":memory_manager", ":type", "//base/internal:value", + "//internal:casts", + "//internal:no_destructor", + "//internal:rtti", + "//internal:strings", + "//internal:time", + "//internal:utf8", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/time", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:variant", ], @@ -241,21 +238,6 @@ cc_library( ":memory_manager", ":type_manager", ":value", - "//base/values:bool", - "//base/values:bytes", - "//base/values:double", - "//base/values:duration", - "//base/values:enum", - "//base/values:error", - "//base/values:int", - "//base/values:list", - "//base/values:map", - "//base/values:null", - "//base/values:string", - "//base/values:struct", - "//base/values:timestamp", - "//base/values:type", - "//base/values:uint", "//internal:status_macros", "//internal:time", "//internal:utf8", diff --git a/base/types/BUILD b/base/types/BUILD deleted file mode 100644 index c82f9af25..000000000 --- a/base/types/BUILD +++ /dev/null @@ -1,241 +0,0 @@ -# Copyright 2021 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -package( - # Under active development, not yet being released. - default_visibility = ["//visibility:public"], -) - -licenses(["notice"]) - -cc_library( - name = "any", - srcs = ["any_type.cc"], - hdrs = ["any_type.h"], - deps = [ - "//base:kind", - "//base:type", - "//internal:no_destructor", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/strings", - ], -) - -cc_library( - name = "bool", - srcs = ["bool_type.cc"], - hdrs = ["bool_type.h"], - deps = [ - "//base:kind", - "//base:type", - "//internal:no_destructor", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/strings", - ], -) - -cc_library( - name = "bytes", - srcs = ["bytes_type.cc"], - hdrs = ["bytes_type.h"], - deps = [ - "//base:kind", - "//base:type", - "//internal:no_destructor", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/strings", - ], -) - -cc_library( - name = "double", - srcs = ["double_type.cc"], - hdrs = ["double_type.h"], - deps = [ - "//base:kind", - "//base:type", - "//internal:no_destructor", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/strings", - ], -) - -cc_library( - name = "duration", - srcs = ["duration_type.cc"], - hdrs = ["duration_type.h"], - deps = [ - "//base:kind", - "//base:type", - "//internal:no_destructor", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/strings", - ], -) - -cc_library( - name = "dyn", - srcs = ["dyn_type.cc"], - hdrs = ["dyn_type.h"], - deps = [ - "//base:kind", - "//base:type", - "//internal:no_destructor", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/strings", - ], -) - -cc_library( - name = "enum", - srcs = ["enum_type.cc"], - hdrs = ["enum_type.h"], - deps = [ - "//base:kind", - "//base:type", - "//internal:rtti", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:variant", - ], -) - -cc_library( - name = "error", - srcs = ["error_type.cc"], - hdrs = ["error_type.h"], - deps = [ - "//base:kind", - "//base:type", - "//internal:no_destructor", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/strings", - ], -) - -cc_library( - name = "int", - srcs = ["int_type.cc"], - hdrs = ["int_type.h"], - deps = [ - "//base:kind", - "//base:type", - "//internal:no_destructor", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/strings", - ], -) - -cc_library( - name = "list", - srcs = ["list_type.cc"], - hdrs = ["list_type.h"], - deps = [ - "//base:kind", - "//base:type", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/strings", - ], -) - -cc_library( - name = "map", - srcs = ["map_type.cc"], - hdrs = ["map_type.h"], - deps = [ - "//base:kind", - "//base:type", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/strings", - ], -) - -cc_library( - name = "null", - srcs = ["null_type.cc"], - hdrs = ["null_type.h"], - deps = [ - "//base:kind", - "//base:type", - "//internal:no_destructor", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/strings", - ], -) - -cc_library( - name = "string", - srcs = ["string_type.cc"], - hdrs = ["string_type.h"], - deps = [ - "//base:kind", - "//base:type", - "//internal:no_destructor", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/strings", - ], -) - -cc_library( - name = "struct", - srcs = ["struct_type.cc"], - hdrs = ["struct_type.h"], - deps = [ - "//base:kind", - "//base:type", - "//internal:rtti", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:variant", - ], -) - -cc_library( - name = "timestamp", - srcs = ["timestamp_type.cc"], - hdrs = ["timestamp_type.h"], - deps = [ - "//base:kind", - "//base:type", - "//internal:no_destructor", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/strings", - ], -) - -cc_library( - name = "type", - srcs = ["type_type.cc"], - hdrs = ["type_type.h"], - deps = [ - "//base:kind", - "//base:type", - "//internal:no_destructor", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/strings", - ], -) - -cc_library( - name = "uint", - srcs = ["uint_type.cc"], - hdrs = ["uint_type.h"], - deps = [ - "//base:kind", - "//base:type", - "//internal:no_destructor", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/strings", - ], -) diff --git a/base/values/BUILD b/base/values/BUILD deleted file mode 100644 index f693cedfc..000000000 --- a/base/values/BUILD +++ /dev/null @@ -1,259 +0,0 @@ -# Copyright 2021 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -package( - # Under active development, not yet being released. - default_visibility = ["//visibility:public"], -) - -licenses(["notice"]) - -cc_library( - name = "bool", - srcs = ["bool_value.cc"], - hdrs = ["bool_value.h"], - deps = [ - "//base:kind", - "//base:type", - "//base:value", - "//base/types:bool", - "//internal:casts", - "@com_google_absl//absl/hash", - ], -) - -cc_library( - name = "bytes", - srcs = ["bytes_value.cc"], - hdrs = ["bytes_value.h"], - deps = [ - "//base:kind", - "//base:type", - "//base:value", - "//base/types:bytes", - "//internal:casts", - "//internal:strings", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:cord", - ], -) - -cc_library( - name = "double", - srcs = ["double_value.cc"], - hdrs = ["double_value.h"], - deps = [ - "//base:kind", - "//base:type", - "//base:value", - "//base/types:double", - "//internal:casts", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/strings", - ], -) - -cc_library( - name = "duration", - srcs = ["duration_value.cc"], - hdrs = ["duration_value.h"], - deps = [ - "//base:kind", - "//base:type", - "//base:value", - "//base/types:duration", - "//internal:casts", - "//internal:time", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/time", - ], -) - -cc_library( - name = "enum", - srcs = ["enum_value.cc"], - hdrs = ["enum_value.h"], - deps = [ - "//base:kind", - "//base:type", - "//base:value", - "//base/types:enum", - "//internal:casts", - "//internal:rtti", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - ], -) - -cc_library( - name = "error", - srcs = ["error_value.cc"], - hdrs = ["error_value.h"], - deps = [ - "//base:kind", - "//base:type", - "//base:value", - "//base/types:error", - "//internal:casts", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/status", - ], -) - -cc_library( - name = "int", - srcs = ["int_value.cc"], - hdrs = ["int_value.h"], - deps = [ - "//base:kind", - "//base:type", - "//base:value", - "//base/types:int", - "//internal:casts", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/strings", - ], -) - -cc_library( - name = "list", - srcs = ["list_value.cc"], - hdrs = ["list_value.h"], - deps = [ - "//base:kind", - "//base:type", - "//base:value", - "//base/types:list", - "//internal:rtti", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - ], -) - -cc_library( - name = "map", - srcs = ["map_value.cc"], - hdrs = ["map_value.h"], - deps = [ - "//base:kind", - "//base:type", - "//base:value", - "//base/types:map", - "//internal:rtti", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - ], -) - -cc_library( - name = "null", - srcs = ["null_value.cc"], - hdrs = ["null_value.h"], - deps = [ - "//base:kind", - "//base:type", - "//base:value", - "//base/types:null", - "//internal:no_destructor", - "@com_google_absl//absl/hash", - ], -) - -cc_library( - name = "string", - srcs = ["string_value.cc"], - hdrs = ["string_value.h"], - deps = [ - "//base:kind", - "//base:type", - "//base:value", - "//base/types:string", - "//internal:casts", - "//internal:strings", - "//internal:utf8", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:cord", - ], -) - -cc_library( - name = "struct", - srcs = ["struct_value.cc"], - hdrs = ["struct_value.h"], - deps = [ - "//base:kind", - "//base:type", - "//base:value", - "//base/types:struct", - "//internal:casts", - "//internal:rtti", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - ], -) - -cc_library( - name = "timestamp", - srcs = ["timestamp_value.cc"], - hdrs = ["timestamp_value.h"], - deps = [ - "//base:kind", - "//base:type", - "//base:value", - "//base/types:timestamp", - "//internal:casts", - "//internal:time", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/time", - ], -) - -cc_library( - name = "type", - srcs = ["type_value.cc"], - hdrs = ["type_value.h"], - deps = [ - "//base:kind", - "//base:type", - "//base:value", - "//base/types:type", - "//internal:casts", - "@com_google_absl//absl/hash", - ], -) - -cc_library( - name = "uint", - srcs = ["uint_value.cc"], - hdrs = ["uint_value.h"], - deps = [ - "//base:kind", - "//base:type", - "//base:value", - "//base/types:uint", - "//internal:casts", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/strings", - ], -) From a063b6edbb92ca65afdccffc6c9f06554db1c5ba Mon Sep 17 00:00:00 2001 From: jcking Date: Thu, 9 Jun 2022 18:32:00 +0000 Subject: [PATCH 004/303] Internal change PiperOrigin-RevId: 453973282 --- base/BUILD | 1 + base/kind.cc | 4 ---- base/kind.h | 5 ++--- base/kind_test.cc | 2 -- 4 files changed, 3 insertions(+), 9 deletions(-) diff --git a/base/BUILD b/base/BUILD index 85d4bdda6..127ac0481 100644 --- a/base/BUILD +++ b/base/BUILD @@ -34,6 +34,7 @@ cc_library( srcs = ["kind.cc"], hdrs = ["kind.h"], deps = [ + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", ], ) diff --git a/base/kind.cc b/base/kind.cc index f1c207e4b..8eccd110f 100644 --- a/base/kind.cc +++ b/base/kind.cc @@ -26,8 +26,6 @@ absl::string_view KindToString(Kind kind) { return "any"; case Kind::kType: return "type"; - case Kind::kTypeParam: - return "type_param"; case Kind::kBool: return "bool"; case Kind::kInt: @@ -52,8 +50,6 @@ absl::string_view KindToString(Kind kind) { return "map"; case Kind::kStruct: return "struct"; - case Kind::kOpaque: - return "opaque"; default: return "*error*"; } diff --git a/base/kind.h b/base/kind.h index cb294075e..e7f6169a0 100644 --- a/base/kind.h +++ b/base/kind.h @@ -15,6 +15,7 @@ #ifndef THIRD_PARTY_CEL_CPP_BASE_KIND_H_ #define THIRD_PARTY_CEL_CPP_BASE_KIND_H_ +#include "absl/base/attributes.h" #include "absl/strings/string_view.h" namespace cel { @@ -25,7 +26,6 @@ enum class Kind { kDyn, kAny, kType, - kTypeParam, kBool, kInt, kUint, @@ -38,10 +38,9 @@ enum class Kind { kList, kMap, kStruct, - kOpaque, }; -absl::string_view KindToString(Kind kind); +ABSL_ATTRIBUTE_PURE_FUNCTION absl::string_view KindToString(Kind kind); } // namespace cel diff --git a/base/kind_test.cc b/base/kind_test.cc index 4069f931d..22050cf9c 100644 --- a/base/kind_test.cc +++ b/base/kind_test.cc @@ -27,7 +27,6 @@ TEST(Kind, ToString) { EXPECT_EQ(KindToString(Kind::kDyn), "dyn"); EXPECT_EQ(KindToString(Kind::kAny), "any"); EXPECT_EQ(KindToString(Kind::kType), "type"); - EXPECT_EQ(KindToString(Kind::kTypeParam), "type_param"); EXPECT_EQ(KindToString(Kind::kBool), "bool"); EXPECT_EQ(KindToString(Kind::kInt), "int"); EXPECT_EQ(KindToString(Kind::kUint), "uint"); @@ -40,7 +39,6 @@ TEST(Kind, ToString) { EXPECT_EQ(KindToString(Kind::kList), "list"); EXPECT_EQ(KindToString(Kind::kMap), "map"); EXPECT_EQ(KindToString(Kind::kStruct), "struct"); - EXPECT_EQ(KindToString(Kind::kOpaque), "opaque"); EXPECT_EQ(KindToString(static_cast(std::numeric_limits::max())), "*error*"); } From a81d608b0d318550eae06f4d87a565db58a7b11c Mon Sep 17 00:00:00 2001 From: kuat Date: Mon, 13 Jun 2022 14:06:38 +0000 Subject: [PATCH 005/303] Internal change. PiperOrigin-RevId: 454599935 --- eval/compiler/flat_expr_builder_test.cc | 2 +- eval/eval/create_struct_step_test.cc | 2 +- eval/eval/function_step_test.cc | 2 +- eval/eval/select_step_test.cc | 2 +- eval/public/builtin_func_test.cc | 2 +- eval/public/structs/cel_proto_wrap_util.cc | 6 +++--- eval/public/structs/field_access_impl_test.cc | 12 ++++++------ eval/public/structs/proto_message_type_adapter.cc | 8 ++++---- .../structs/protobuf_descriptor_type_provider.cc | 2 +- 9 files changed, 19 insertions(+), 19 deletions(-) diff --git a/eval/compiler/flat_expr_builder_test.cc b/eval/compiler/flat_expr_builder_test.cc index f12486737..8797d8608 100644 --- a/eval/compiler/flat_expr_builder_test.cc +++ b/eval/compiler/flat_expr_builder_test.cc @@ -1895,7 +1895,7 @@ TEST(FlatExprBuilderTest, CustomDescriptorPoolForSelect) { std::pair CreateTestMessage( const google::protobuf::DescriptorPool& descriptor_pool, google::protobuf::MessageFactory& message_factory, absl::string_view name) { - const google::protobuf::Descriptor* desc = descriptor_pool.FindMessageTypeByName(name.data()); + const google::protobuf::Descriptor* desc = descriptor_pool.FindMessageTypeByName(std::string(name)); const google::protobuf::Message* message_prototype = message_factory.GetPrototype(desc); google::protobuf::Message* message = message_prototype->New(); const google::protobuf::Reflection* refl = message->GetReflection(); diff --git a/eval/eval/create_struct_step_test.cc b/eval/eval/create_struct_step_test.cc index 85efc2d2f..a3190acfd 100644 --- a/eval/eval/create_struct_step_test.cc +++ b/eval/eval/create_struct_step_test.cc @@ -64,7 +64,7 @@ absl::StatusOr RunExpression(absl::string_view field, create_struct->set_message_name("google.api.expr.runtime.TestMessage"); auto entry = create_struct->add_entries(); - entry->set_field_key(field.data()); + entry->set_field_key(std::string(field)); auto adapter = type_registry.FindTypeAdapter(create_struct->message_name()); if (!adapter.has_value() || adapter->mutation_apis() == nullptr) { diff --git a/eval/eval/function_step_test.cc b/eval/eval/function_step_test.cc index 223d6eb83..3fcec6dc1 100644 --- a/eval/eval/function_step_test.cc +++ b/eval/eval/function_step_test.cc @@ -56,7 +56,7 @@ class ConstFunction : public CelFunction { static Expr::Call MakeCall(absl::string_view name) { Expr::Call call; - call.set_function(name.data()); + call.set_function(std::string(name)); call.clear_target(); return call; } diff --git a/eval/eval/select_step_test.cc b/eval/eval/select_step_test.cc index efe202cc8..8cf94d609 100644 --- a/eval/eval/select_step_test.cc +++ b/eval/eval/select_step_test.cc @@ -77,7 +77,7 @@ absl::StatusOr RunExpression(const CelValue target, Expr dummy_expr; auto select = dummy_expr.mutable_select_expr(); - select->set_field(field.data()); + select->set_field(std::string(field)); select->set_test_only(test); Expr* expr0 = select->mutable_operand(); diff --git a/eval/public/builtin_func_test.cc b/eval/public/builtin_func_test.cc index c30633004..d0065e788 100644 --- a/eval/public/builtin_func_test.cc +++ b/eval/public/builtin_func_test.cc @@ -68,7 +68,7 @@ class BuiltinsTest : public ::testing::Test { Expr expr; SourceInfo source_info; auto call = expr.mutable_call_expr(); - call->set_function(operation.data()); + call->set_function(std::string(operation)); if (target.has_value()) { std::string param_name = "target"; diff --git a/eval/public/structs/cel_proto_wrap_util.cc b/eval/public/structs/cel_proto_wrap_util.cc index 8ff817b7d..8d03bee87 100644 --- a/eval/public/structs/cel_proto_wrap_util.cc +++ b/eval/public/structs/cel_proto_wrap_util.cc @@ -432,7 +432,7 @@ google::protobuf::Message* MessageFromValue(const CelValue& value, BytesValue* w if (!value.GetValue(&view_val)) { return nullptr; } - wrapper->set_value(view_val.value().data()); + wrapper->set_value(std::string(view_val.value())); return wrapper; } @@ -490,7 +490,7 @@ google::protobuf::Message* MessageFromValue(const CelValue& value, StringValue* if (!value.GetValue(&view_val)) { return nullptr; } - wrapper->set_value(view_val.value().data()); + wrapper->set_value(std::string(view_val.value())); return wrapper; } @@ -627,7 +627,7 @@ google::protobuf::Message* MessageFromValue(const CelValue& value, Value* json) case CelValue::Type::kString: { CelValue::StringHolder val; if (value.GetValue(&val)) { - json->set_string_value(val.value().data()); + json->set_string_value(std::string(val.value())); return json; } } break; diff --git a/eval/public/structs/field_access_impl_test.cc b/eval/public/structs/field_access_impl_test.cc index afda4d93b..d5f259127 100644 --- a/eval/public/structs/field_access_impl_test.cc +++ b/eval/public/structs/field_access_impl_test.cc @@ -184,14 +184,14 @@ class SingleFieldTest : public testing::TestWithParam { TEST_P(SingleFieldTest, Getter) { TestAllTypes test_message; ASSERT_TRUE( - google::protobuf::TextFormat::ParseFromString(message_textproto().data(), &test_message)); + google::protobuf::TextFormat::ParseFromString(std::string(message_textproto()), &test_message)); google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN( CelValue accessed_value, CreateValueFromSingleField( &test_message, - test_message.GetDescriptor()->FindFieldByName(field_name().data()), + test_message.GetDescriptor()->FindFieldByName(std::string(field_name())), ProtoWrapperTypeOptions::kUnsetProtoDefault, &CelProtoWrapper::InternalWrapMessage, &arena)); @@ -204,7 +204,7 @@ TEST_P(SingleFieldTest, Setter) { google::protobuf::Arena arena; ASSERT_OK(SetValueToSingleField( - to_set, test_message.GetDescriptor()->FindFieldByName(field_name().data()), + to_set, test_message.GetDescriptor()->FindFieldByName(std::string(field_name())), &test_message, &arena)); EXPECT_THAT(test_message, EqualsProto(message_textproto())); @@ -361,14 +361,14 @@ class RepeatedFieldTest : public testing::TestWithParam { TEST_P(RepeatedFieldTest, GetFirstElem) { TestAllTypes test_message; ASSERT_TRUE( - google::protobuf::TextFormat::ParseFromString(message_textproto().data(), &test_message)); + google::protobuf::TextFormat::ParseFromString(std::string(message_textproto()), &test_message)); google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN( CelValue accessed_value, CreateValueFromRepeatedField( &test_message, - test_message.GetDescriptor()->FindFieldByName(field_name().data()), 0, + test_message.GetDescriptor()->FindFieldByName(std::string(field_name())), 0, &CelProtoWrapper::InternalWrapMessage, &arena)); EXPECT_THAT(accessed_value, test::EqualsCelValue(cel_value())); @@ -380,7 +380,7 @@ TEST_P(RepeatedFieldTest, AppendElem) { google::protobuf::Arena arena; ASSERT_OK(AddValueToRepeatedField( - to_add, test_message.GetDescriptor()->FindFieldByName(field_name().data()), + to_add, test_message.GetDescriptor()->FindFieldByName(std::string(field_name())), &test_message, &arena)); EXPECT_THAT(test_message, EqualsProto(message_textproto())); diff --git a/eval/public/structs/proto_message_type_adapter.cc b/eval/public/structs/proto_message_type_adapter.cc index 1a0eda8f2..ccd2cded8 100644 --- a/eval/public/structs/proto_message_type_adapter.cc +++ b/eval/public/structs/proto_message_type_adapter.cc @@ -86,7 +86,7 @@ absl::StatusOr HasFieldImpl(const google::protobuf::Message* message, absl::string_view field_name) { ABSL_ASSERT(descriptor == message->GetDescriptor()); const Reflection* reflection = message->GetReflection(); - const FieldDescriptor* field_desc = descriptor->FindFieldByName(field_name.data()); + const FieldDescriptor* field_desc = descriptor->FindFieldByName(std::string(field_name)); if (field_desc == nullptr) { return absl::NotFoundError(absl::StrCat("no_such_field : ", field_name)); @@ -118,7 +118,7 @@ absl::StatusOr GetFieldImpl(const google::protobuf::Message* message, ProtoWrapperTypeOptions unboxing_option, cel::MemoryManager& memory_manager) { ABSL_ASSERT(descriptor == message->GetDescriptor()); - const FieldDescriptor* field_desc = descriptor->FindFieldByName(field_name.data()); + const FieldDescriptor* field_desc = descriptor->FindFieldByName(std::string(field_name)); if (field_desc == nullptr) { return CreateNoSuchFieldError(memory_manager, field_name); @@ -249,7 +249,7 @@ ProtoMessageTypeAdapter::NewInstance(cel::MemoryManager& memory_manager) const { } bool ProtoMessageTypeAdapter::DefinesField(absl::string_view field_name) const { - return descriptor_->FindFieldByName(field_name.data()) != nullptr; + return descriptor_->FindFieldByName(std::string(field_name)) != nullptr; } absl::StatusOr ProtoMessageTypeAdapter::HasField( @@ -282,7 +282,7 @@ absl::Status ProtoMessageTypeAdapter::SetField( UnwrapMessage(instance, "SetField")); const google::protobuf::FieldDescriptor* field_descriptor = - descriptor_->FindFieldByName(field_name.data()); + descriptor_->FindFieldByName(std::string(field_name)); CEL_RETURN_IF_ERROR( ValidateSetFieldOp(field_descriptor != nullptr, field_name, "not found")); diff --git a/eval/public/structs/protobuf_descriptor_type_provider.cc b/eval/public/structs/protobuf_descriptor_type_provider.cc index 6467c7835..c770c04e7 100644 --- a/eval/public/structs/protobuf_descriptor_type_provider.cc +++ b/eval/public/structs/protobuf_descriptor_type_provider.cc @@ -48,7 +48,7 @@ absl::optional ProtobufDescriptorProvider::ProvideLegacyType( std::unique_ptr ProtobufDescriptorProvider::GetType( absl::string_view name) const { const google::protobuf::Descriptor* descriptor = - descriptor_pool_->FindMessageTypeByName(name.data()); + descriptor_pool_->FindMessageTypeByName(std::string(name)); if (descriptor == nullptr) { return nullptr; } From d97681711951c41339f3ef29e884ae8349a1c81f Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 15 Jun 2022 20:19:30 +0000 Subject: [PATCH 006/303] Make the AST types more convenient to use. PiperOrigin-RevId: 455204424 --- base/BUILD | 1 + base/ast.cc | 169 ++++++++++++++ base/ast.h | 484 +++++++++++++++++++++++++++++++++++++-- base/ast_test.cc | 215 +++++++++++++---- base/ast_utility.cc | 23 +- base/ast_utility_test.cc | 348 +++++++++++----------------- 6 files changed, 964 insertions(+), 276 deletions(-) create mode 100644 base/ast.cc diff --git a/base/BUILD b/base/BUILD index 127ac0481..f6b0157dd 100644 --- a/base/BUILD +++ b/base/BUILD @@ -279,6 +279,7 @@ cc_test( cc_library( name = "ast", + srcs = ["ast.cc"], hdrs = [ "ast.h", ], diff --git a/base/ast.cc b/base/ast.cc new file mode 100644 index 000000000..7ae327bdd --- /dev/null +++ b/base/ast.cc @@ -0,0 +1,169 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/ast.h" + +#include + +namespace cel::ast::internal { + +namespace { +const Expr& default_expr() { + static Expr* expr = new Expr(); + return *expr; +} +} // namespace + +const Expr& Select::operand() const { + if (operand_ != nullptr) { + return *operand_; + } + return default_expr(); +} + +bool Select::operator==(const Select& other) const { + return operand() == other.operand() && field_ == other.field_ && + test_only_ == other.test_only_; +} + +const Expr& Call::target() const { + if (target_ != nullptr) { + return *target_; + } + return default_expr(); +} + +bool Call::operator==(const Call& other) const { + return target() == other.target() && function_ == other.function_ && + args_ == other.args_; +} + +const Expr& CreateStruct::Entry::map_key() const { + auto* value = absl::get_if>(&key_kind_); + if (value != nullptr) { + if (*value != nullptr) return **value; + } + return default_expr(); +} + +const Expr& CreateStruct::Entry::value() const { + if (value_ != nullptr) { + return *value_; + } + return default_expr(); +} + +bool CreateStruct::Entry::operator==(const Entry& other) const { + return id_ == other.id_ && key_kind_ == other.key_kind_ && + value() == other.value(); +} + +const Expr& Comprehension::iter_range() const { + if (iter_range_ != nullptr) { + return *iter_range_; + } + return default_expr(); +} + +const Expr& Comprehension::accu_init() const { + if (accu_init_ != nullptr) { + return *accu_init_; + } + return default_expr(); +} + +const Expr& Comprehension::loop_condition() const { + if (loop_condition_ != nullptr) { + return *loop_condition_; + } + return default_expr(); +} + +const Expr& Comprehension::loop_step() const { + if (loop_step_ != nullptr) { + return *loop_step_; + } + return default_expr(); +} + +const Expr& Comprehension::result() const { + if (result_ != nullptr) { + return *result_; + } + return default_expr(); +} + +bool Comprehension::operator==(const Comprehension& other) const { + return iter_var_ == other.iter_var_ && iter_range() == other.iter_range() && + accu_var_ == other.accu_var_ && accu_init() == other.accu_init() && + loop_condition() == other.loop_condition() && + loop_step() == other.loop_step() && result() == other.result(); +} + +namespace { +const Type& default_type() { + static Type* type = new Type(); + return *type; +} +} // namespace + +const Type& ListType::elem_type() const { + if (elem_type_ != nullptr) { + return *elem_type_; + } + return default_type(); +} + +bool ListType::operator==(const ListType& other) const { + return elem_type() == other.elem_type(); +} + +const Type& MapType::key_type() const { + if (key_type_ != nullptr) { + return *key_type_; + } + return default_type(); +} + +const Type& MapType::value_type() const { + if (value_type_ != nullptr) { + return *value_type_; + } + return default_type(); +} + +bool MapType::operator==(const MapType& other) const { + return key_type() == other.key_type() && value_type() == other.value_type(); +} + +const Type& FunctionType::result_type() const { + if (result_type_ != nullptr) { + return *result_type_; + } + return default_type(); +} + +bool FunctionType::operator==(const FunctionType& other) const { + return result_type() == other.result_type() && arg_types_ == other.arg_types_; +} + +const Type& Type::type() const { + auto* value = absl::get_if>(&type_kind_); + if (value != nullptr) { + if (*value != nullptr) return **value; + } + return default_type(); +} + +} // namespace cel::ast::internal diff --git a/base/ast.h b/base/ast.h index a4fcc34ac..da735861d 100644 --- a/base/ast.h +++ b/base/ast.h @@ -24,7 +24,6 @@ #include #include -#include "absl/base/macros.h" #include "absl/container/flat_hash_map.h" #include "absl/time/time.h" #include "absl/types/variant.h" @@ -49,20 +48,144 @@ enum class NullValue { kNullValue = 0 }; // message that can hold any constant object representation supplied or // produced at evaluation time. // --) -using Constant = absl::variant; +using ConstantKind = absl::variant; + +class Constant { + public: + constexpr Constant() {} + explicit Constant(ConstantKind constant_kind) + : constant_kind_(std::move(constant_kind)) {} + + void set_constant_kind(ConstantKind constant_kind) { + constant_kind_ = std::move(constant_kind); + } + + const ConstantKind& constant_kind() const { return constant_kind_; } + + ConstantKind& mutable_constant_kind() { return constant_kind_; } + + bool has_null_value() const { + return absl::holds_alternative(constant_kind_); + } + + NullValue null_value() const { + auto* value = absl::get_if(&constant_kind_); + if (value != nullptr) { + return *value; + } + return NullValue::kNullValue; + } + + bool has_bool_value() const { + return absl::holds_alternative(constant_kind_); + } + + bool bool_value() const { + auto* value = absl::get_if(&constant_kind_); + if (value != nullptr) { + return *value; + } + return false; + } + + bool has_int64_value() const { + return absl::holds_alternative(constant_kind_); + } + + int64_t int64_value() const { + auto* value = absl::get_if(&constant_kind_); + if (value != nullptr) { + return *value; + } + return 0; + } + + bool has_uint64_value() const { + return absl::holds_alternative(constant_kind_); + } + + uint64_t uint64_value() const { + auto* value = absl::get_if(&constant_kind_); + if (value != nullptr) { + return *value; + } + return 0; + } + + bool has_double_value() const { + return absl::holds_alternative(constant_kind_); + } + + double double_value() const { + auto* value = absl::get_if(&constant_kind_); + if (value != nullptr) { + return *value; + } + return 0; + } + + bool has_string_value() const { + return absl::holds_alternative(constant_kind_); + } + + const std::string& string_value() const { + auto* value = absl::get_if(&constant_kind_); + if (value != nullptr) { + return *value; + } + static std::string* default_string_value_ = new std::string(""); + return *default_string_value_; + } + + bool has_duration_value() const { + return absl::holds_alternative(constant_kind_); + } + + const absl::Duration& duration_value() const { + auto* value = absl::get_if(&constant_kind_); + if (value != nullptr) { + return *value; + } + static absl::Duration default_duration_; + return default_duration_; + } + + bool has_time_value() const { + return absl::holds_alternative(constant_kind_); + } + + const absl::Time& time_value() const { + auto* value = absl::get_if(&constant_kind_); + if (value != nullptr) { + return *value; + } + static absl::Time default_time_; + return default_time_; + } + + bool operator==(const Constant& other) const { + return constant_kind_ == other.constant_kind_; + } + + private: + ConstantKind constant_kind_; +}; class Expr; // An identifier expression. e.g. `request`. class Ident { public: + Ident() {} explicit Ident(std::string name) : name_(std::move(name)) {} void set_name(std::string name) { name_ = std::move(name); } const std::string& name() const { return name_; } + bool operator==(const Ident& other) const { return name_ == other.name_; } + private: // Required. Holds a single, unqualified identifier, possibly preceded by a // '.'. @@ -89,7 +212,9 @@ class Select { void set_test_only(bool test_only) { test_only_ = test_only; } - const Expr* operand() const { return operand_.get(); } + bool has_operand() const { return operand_ != nullptr; } + + const Expr& operand() const; Expr& mutable_operand() { if (operand_ == nullptr) { @@ -102,6 +227,8 @@ class Select { bool test_only() const { return test_only_; } + bool operator==(const Select& other) const; + private: // Required. The target of the selection expression. // @@ -116,7 +243,7 @@ class Select { // Whether the select is to be interpreted as a field presence test. // // This results from the macro `has(request.auth)`. - bool test_only_; + bool test_only_ = false; }; // A call expression, including calls to predefined functions and operators. @@ -138,7 +265,9 @@ class Call { void set_args(std::vector args) { args_ = std::move(args); } - const Expr* target() const { return target_.get(); } + bool has_target() const { return target_ != nullptr; } + + const Expr& target() const; Expr& mutable_target() { if (target_ == nullptr) { @@ -153,6 +282,8 @@ class Call { std::vector& mutable_args() { return args_; } + bool operator==(const Call& other) const; + private: // The target of an method call-style expression. For example, `x` in // `x.f()`. @@ -185,6 +316,10 @@ class CreateList { std::vector& mutable_elements() { return elements_; } + bool operator==(const CreateList& other) const { + return elements_ == other.elements_; + } + private: // The elements part of the list. std::vector elements_; @@ -217,7 +352,28 @@ class CreateStruct { KeyKind& mutable_key_kind() { return key_kind_; } - const Expr* value() const { return value_.get(); } + bool has_field_key() const { + return absl::holds_alternative(key_kind_); + } + + bool has_map_key() const { + return absl::holds_alternative>(key_kind_); + } + + const std::string& field_key() const { + auto* value = absl::get_if(&key_kind_); + if (value != nullptr) { + return *value; + } + static const std::string* default_field_key = new std::string; + return *default_field_key; + } + + const Expr& map_key() const; + + bool has_value() const { return value_ != nullptr; } + + const Expr& value() const; Expr& mutable_value() { if (value_ == nullptr) { @@ -226,6 +382,8 @@ class CreateStruct { return *value_; } + bool operator==(const Entry& other) const; + private: // Required. An id assigned to this node by the parser which is unique // in a given expression tree. This is used to associate type @@ -253,6 +411,10 @@ class CreateStruct { std::vector& mutable_entries() { return entries_; } + bool operator==(const CreateStruct& other) const { + return message_name_ == other.message_name_ && entries_ == other.entries_; + } + private: // The type name of the message to be created, empty when creating map // literals. @@ -321,6 +483,16 @@ class Comprehension { loop_step_(std::move(loop_step)), result_(std::move(result)) {} + bool has_iter_range() const { return iter_range_ != nullptr; } + + bool has_accu_init() const { return accu_init_ != nullptr; } + + bool has_loop_condition() const { return loop_condition_ != nullptr; } + + bool has_loop_step() const { return loop_step_ != nullptr; } + + bool has_result() const { return result_ != nullptr; } + void set_iter_var(std::string iter_var) { iter_var_ = std::move(iter_var); } void set_iter_range(std::unique_ptr iter_range) { @@ -345,7 +517,7 @@ class Comprehension { const std::string& iter_var() const { return iter_var_; } - const Expr* iter_range() const { return iter_range_.get(); } + const Expr& iter_range() const; Expr& mutable_iter_range() { if (iter_range_ == nullptr) { @@ -356,7 +528,7 @@ class Comprehension { const std::string& accu_var() const { return accu_var_; } - const Expr* accu_init() const { return accu_init_.get(); } + const Expr& accu_init() const; Expr& mutable_accu_init() { if (accu_init_ == nullptr) { @@ -365,7 +537,7 @@ class Comprehension { return *accu_init_; } - const Expr* loop_condition() const { return loop_condition_.get(); } + const Expr& loop_condition() const; Expr& mutable_loop_condition() { if (loop_condition_ == nullptr) { @@ -374,7 +546,7 @@ class Comprehension { return *loop_condition_; } - const Expr* loop_step() const { return loop_step_.get(); } + const Expr& loop_step() const; Expr& mutable_loop_step() { if (loop_step_ == nullptr) { @@ -383,7 +555,7 @@ class Comprehension { return *loop_step_; } - const Expr* result() const { return result_.get(); } + const Expr& result() const; Expr& mutable_result() { if (result_ == nullptr) { @@ -392,6 +564,8 @@ class Comprehension { return *result_; } + bool operator==(const Comprehension& other) const; + private: // The name of the iteration variable. std::string iter_var_; @@ -461,6 +635,101 @@ class Expr { ExprKind& mutable_expr_kind() { return expr_kind_; } + bool has_const_expr() const { + return absl::holds_alternative(expr_kind_); + } + + bool has_ident_expr() const { + return absl::holds_alternative(expr_kind_); + } + + bool has_select_expr() const { + return absl::holds_alternative(&expr_kind_); + if (value != nullptr) { + return *value; + } + static const Select* default_select = new Select; + return *default_select; + } + + const Call& call_expr() const { + auto* value = absl::get_if(&expr_kind_); + if (value != nullptr) { + return *value; + } + static const Call* default_call = new Call; + return *default_call; + } + + const CreateList& list_expr() const { + auto* value = absl::get_if(&expr_kind_); + if (value != nullptr) { + return *value; + } + static const CreateList* default_create_list = new CreateList; + return *default_create_list; + } + + const CreateStruct& struct_expr() const { + auto* value = absl::get_if(&expr_kind_); + if (value != nullptr) { + return *value; + } + static const CreateStruct* default_create_struct = new CreateStruct; + return *default_create_struct; + } + + const Comprehension& comprehension_expr() const { + auto* value = absl::get_if(&expr_kind_); + if (value != nullptr) { + return *value; + } + static const Comprehension* default_comprehension = new Comprehension; + return *default_comprehension; + } + + bool operator==(const Expr& other) const { + return id_ == other.id_ && expr_kind_ == other.expr_kind_; + } + private: // Required. An id assigned to this node by the parser which is unique in a // given expression tree. This is used to associate type information and other @@ -651,7 +920,9 @@ class ListType { elem_type_ = std::move(elem_type); } - const Type* elem_type() const { return elem_type_.get(); } + bool has_elem_type() const { return elem_type_ != nullptr; } + + const Type& elem_type() const; Type& mutable_elem_type() { if (elem_type_ == nullptr) { @@ -660,6 +931,8 @@ class ListType { return *elem_type_; } + bool operator==(const ListType& other) const; + private: std::unique_ptr elem_type_; }; @@ -679,9 +952,15 @@ class MapType { value_type_ = std::move(value_type); } - const Type* key_type() const { return key_type_.get(); } + bool has_key_type() const { return key_type_ != nullptr; } + + bool has_value_type() const { return value_type_ != nullptr; } - const Type* value_type() const { return value_type_.get(); } + const Type& key_type() const; + + const Type& value_type() const; + + bool operator==(const MapType& other) const; Type& mutable_key_type() { if (key_type_ == nullptr) { @@ -726,7 +1005,9 @@ class FunctionType { arg_types_ = std::move(arg_types); } - const Type* result_type() const { return result_type_.get(); } + bool has_result_type() const { return result_type_ != nullptr; } + + const Type& result_type() const; Type& mutable_result_type() { if (result_type_ == nullptr) { @@ -739,6 +1020,8 @@ class FunctionType { std::vector& mutable_arg_types() { return arg_types_; } + bool operator==(const FunctionType& other) const; + private: // Result type of the function. std::unique_ptr result_type_; @@ -752,6 +1035,7 @@ class FunctionType { // TODO(issues/5): decide on final naming for this. class AbstractType { public: + AbstractType() {} AbstractType(std::string name, std::vector parameter_types) : name_(std::move(name)), parameter_types_(std::move(parameter_types)) {} @@ -767,6 +1051,10 @@ class AbstractType { std::vector& mutable_parameter_types() { return parameter_types_; } + bool operator==(const AbstractType& other) const { + return name_ == other.name_ && parameter_types_ == other.parameter_types_; + } + private: // The fully qualified name of this abstract type. std::string name_; @@ -786,6 +1074,10 @@ class PrimitiveTypeWrapper { PrimitiveType& mutable_type() { return type_; } + bool operator==(const PrimitiveTypeWrapper& other) const { + return type_ == other.type_; + } + private: PrimitiveType type_; }; @@ -796,12 +1088,17 @@ class PrimitiveTypeWrapper { // example, `google.plus.Profile`. class MessageType { public: + MessageType() {} explicit MessageType(std::string type) : type_(std::move(type)) {} void set_type(std::string type) { type_ = std::move(type); } const std::string& type() const { return type_; } + bool operator==(const MessageType& other) const { + return type_ == other.type_; + } + private: std::string type_; }; @@ -813,12 +1110,15 @@ class MessageType { // named `E`. class ParamType { public: + ParamType() {} explicit ParamType(std::string type) : type_(std::move(type)) {} void set_type(std::string type) { type_ = std::move(type); } const std::string& type() const { return type_; } + bool operator==(const ParamType& other) const { return type_ == other.type_; } + private: std::string type_; }; @@ -855,6 +1155,158 @@ class Type { TypeKind& mutable_type_kind() { return type_kind_; } + bool has_dyn() const { + return absl::holds_alternative(type_kind_); + } + + bool has_null() const { + return absl::holds_alternative(type_kind_); + } + + bool has_primitive() const { + return absl::holds_alternative(type_kind_); + } + + bool has_wrapper() const { + return absl::holds_alternative(type_kind_); + } + + bool has_well_known() const { + return absl::holds_alternative(type_kind_); + } + + bool has_list_type() const { + return absl::holds_alternative(type_kind_); + } + + bool has_map_type() const { + return absl::holds_alternative(type_kind_); + } + + bool has_function() const { + return absl::holds_alternative(type_kind_); + } + + bool has_message_type() const { + return absl::holds_alternative(type_kind_); + } + + bool has_type_param() const { + return absl::holds_alternative(type_kind_); + } + + bool has_type() const { + return absl::holds_alternative>(type_kind_); + } + + bool has_error() const { + return absl::holds_alternative(type_kind_); + } + + bool has_abstract_type() const { + return absl::holds_alternative(type_kind_); + } + + NullValue null() const { + auto* value = absl::get_if(&type_kind_); + if (value != nullptr) { + return *value; + } + return NullValue::kNullValue; + } + + PrimitiveType primitive() const { + auto* value = absl::get_if(&type_kind_); + if (value != nullptr) { + return *value; + } + return PrimitiveType::kPrimitiveTypeUnspecified; + } + + PrimitiveType wrapper() const { + auto* value = absl::get_if(&type_kind_); + if (value != nullptr) { + return value->type(); + } + return PrimitiveType::kPrimitiveTypeUnspecified; + } + + WellKnownType well_known() const { + auto* value = absl::get_if(&type_kind_); + if (value != nullptr) { + return *value; + } + return WellKnownType::kWellKnownTypeUnspecified; + } + + const ListType& list_type() const { + auto* value = absl::get_if(&type_kind_); + if (value != nullptr) { + return *value; + } + static const ListType* default_list_type = new ListType(); + return *default_list_type; + } + + const MapType& map_type() const { + auto* value = absl::get_if(&type_kind_); + if (value != nullptr) { + return *value; + } + static const MapType* default_map_type = new MapType(); + return *default_map_type; + } + + const FunctionType& function() const { + auto* value = absl::get_if(&type_kind_); + if (value != nullptr) { + return *value; + } + static const FunctionType* default_function_type = new FunctionType(); + return *default_function_type; + } + + const MessageType& message_type() const { + auto* value = absl::get_if(&type_kind_); + if (value != nullptr) { + return *value; + } + static const MessageType* default_message_type = new MessageType(); + return *default_message_type; + } + + const ParamType& type_param() const { + auto* value = absl::get_if(&type_kind_); + if (value != nullptr) { + return *value; + } + static const ParamType* default_param_type = new ParamType(); + return *default_param_type; + } + + const Type& type() const; + + ErrorType error_type() const { + auto* value = absl::get_if(&type_kind_); + if (value != nullptr) { + return *value; + } + return ErrorType::kErrorTypeValue; + } + + const AbstractType& abstract_type() const { + auto* value = absl::get_if(&type_kind_); + if (value != nullptr) { + return *value; + } + static const AbstractType* default_abstract_type = new AbstractType(); + return *default_abstract_type; + } + + bool operator==(const Type& other) const { + return type_kind_ == other.type_kind_; + } + private: TypeKind type_kind_; }; diff --git a/base/ast_test.cc b/base/ast_test.cc index 8f1bf3bd7..222a1d216 100644 --- a/base/ast_test.cc +++ b/base/ast_test.cc @@ -18,7 +18,6 @@ #include #include "absl/memory/memory.h" -#include "absl/types/variant.h" #include "internal/testing.h" namespace cel { @@ -26,11 +25,23 @@ namespace ast { namespace internal { namespace { TEST(AstTest, ExprConstructionConstant) { - Expr expr(1, true); + Expr expr(1, Constant(true)); ASSERT_TRUE(absl::holds_alternative(expr.expr_kind())); const auto& constant = absl::get(expr.expr_kind()); - ASSERT_TRUE(absl::holds_alternative(constant)); - ASSERT_TRUE(absl::get(constant)); + ASSERT_TRUE(constant.has_bool_value()); + ASSERT_TRUE(constant.bool_value()); +} + +TEST(AstTest, ConstantDefaults) { + Constant constant; + EXPECT_EQ(constant.null_value(), NullValue::kNullValue); + EXPECT_EQ(constant.bool_value(), false); + EXPECT_EQ(constant.int64_value(), 0); + EXPECT_EQ(constant.uint64_value(), 0); + EXPECT_EQ(constant.double_value(), 0); + EXPECT_TRUE(constant.string_value().empty()); + EXPECT_EQ(constant.duration_value(), absl::Duration()); + EXPECT_EQ(constant.time_value(), absl::UnixEpoch()); } TEST(AstTest, ExprConstructionIdent) { @@ -43,24 +54,41 @@ TEST(AstTest, ExprConstructionSelect) { Expr expr(1, Select(std::make_unique(2, Ident("var")), "field")); ASSERT_TRUE(absl::holds_alternative(expr.expr_kind()); - ASSERT_TRUE(absl::holds_alternative(select.operand()->expr_kind())); - ASSERT_EQ(absl::get(select.operand()->expr_kind()).name(), "var"); + ASSERT_TRUE(absl::holds_alternative(select.operand().expr_kind())); + ASSERT_EQ(absl::get(select.operand().expr_kind()).name(), "var"); ASSERT_EQ(select.field(), "field"); } TEST(AstTest, SelectMutableOperand) { Select select; select.mutable_operand().set_expr_kind(Ident("var")); - ASSERT_TRUE(absl::holds_alternative(select.operand()->expr_kind())); - ASSERT_EQ(absl::get(select.operand()->expr_kind()).name(), "var"); + ASSERT_TRUE(absl::holds_alternative(select.operand().expr_kind())); + ASSERT_EQ(absl::get(select.operand().expr_kind()).name(), "var"); +} + +TEST(AstTest, SelectDefaultOperand) { + Select select; + EXPECT_EQ(select.operand(), Expr()); +} + +TEST(AstTest, SelectComparatorTestOnly) { + Select select; + select.set_test_only(true); + EXPECT_FALSE(select == Select()); +} + +TEST(AstTest, SelectComparatorField) { + Select select; + select.set_field("field"); + EXPECT_FALSE(select == Select()); } TEST(AstTest, ExprConstructionCall) { Expr expr(1, Call(std::make_unique(2, Ident("var")), "function", {})); ASSERT_TRUE(absl::holds_alternative(expr.expr_kind())); const auto& call = absl::get(expr.expr_kind()); - ASSERT_TRUE(absl::holds_alternative(call.target()->expr_kind())); - ASSERT_EQ(absl::get(call.target()->expr_kind()).name(), "var"); + ASSERT_TRUE(absl::holds_alternative(call.target().expr_kind())); + ASSERT_EQ(absl::get(call.target().expr_kind()).name(), "var"); ASSERT_EQ(call.function(), "function"); ASSERT_TRUE(call.args().empty()); } @@ -68,8 +96,28 @@ TEST(AstTest, ExprConstructionCall) { TEST(AstTest, CallMutableTarget) { Call call; call.mutable_target().set_expr_kind(Ident("var")); - ASSERT_TRUE(absl::holds_alternative(call.target()->expr_kind())); - ASSERT_EQ(absl::get(call.target()->expr_kind()).name(), "var"); + ASSERT_TRUE(absl::holds_alternative(call.target().expr_kind())); + ASSERT_EQ(absl::get(call.target().expr_kind()).name(), "var"); +} + +TEST(AstTest, CallDefaultTarget) { EXPECT_EQ(Call().target(), Expr()); } + +TEST(AstTest, CallComparatorTarget) { + Call call; + call.set_function("function"); + EXPECT_FALSE(call == Call()); +} + +TEST(AstTest, CallComparatorArgs) { + Call call; + call.mutable_args().emplace_back(Expr()); + EXPECT_FALSE(call == Call()); +} + +TEST(AstTest, CallComparatorFunction) { + Call call; + call.set_function("function"); + EXPECT_FALSE(call == Call()); } TEST(AstTest, ExprConstructionCreateList) { @@ -99,22 +147,29 @@ TEST(AstTest, ExprConstructionCreateStruct) { ASSERT_TRUE(absl::holds_alternative(expr.expr_kind())); const auto& entries = absl::get(expr.expr_kind()).entries(); ASSERT_EQ(absl::get(entries[0].key_kind()), "key1"); - ASSERT_EQ(absl::get(entries[0].value()->expr_kind()).name(), "value1"); + ASSERT_EQ(absl::get(entries[0].value().expr_kind()).name(), "value1"); ASSERT_EQ(absl::get(entries[1].key_kind()), "key2"); - ASSERT_EQ(absl::get(entries[1].value()->expr_kind()).name(), "value2"); + ASSERT_EQ(absl::get(entries[1].value().expr_kind()).name(), "value2"); ASSERT_EQ( absl::get( absl::get>(entries[2].key_kind())->expr_kind()) .name(), "key3"); - ASSERT_EQ(absl::get(entries[2].value()->expr_kind()).name(), "value3"); + ASSERT_EQ(absl::get(entries[2].value().expr_kind()).name(), "value3"); +} + +TEST(AstTest, ExprCreateStructEntryDefaults) { + CreateStruct::Entry entry; + EXPECT_TRUE(entry.field_key().empty()); + EXPECT_EQ(entry.map_key(), Expr()); + EXPECT_EQ(entry.value(), Expr()); } TEST(AstTest, CreateStructEntryMutableValue) { CreateStruct::Entry entry; entry.mutable_value().set_expr_kind(Ident("var")); - ASSERT_TRUE(absl::holds_alternative(entry.value()->expr_kind())); - ASSERT_EQ(absl::get(entry.value()->expr_kind()).name(), "var"); + ASSERT_TRUE(absl::holds_alternative(entry.value().expr_kind())); + ASSERT_EQ(absl::get(entry.value().expr_kind()).name(), "var"); } TEST(AstTest, ExprConstructionComprehension) { @@ -130,16 +185,16 @@ TEST(AstTest, ExprConstructionComprehension) { ASSERT_TRUE(absl::holds_alternative(expr.expr_kind())); auto& created_expr = absl::get(expr.expr_kind()); ASSERT_EQ(created_expr.iter_var(), "iter_var"); - ASSERT_EQ(absl::get(created_expr.iter_range()->expr_kind()).name(), + ASSERT_EQ(absl::get(created_expr.iter_range().expr_kind()).name(), "range"); ASSERT_EQ(created_expr.accu_var(), "accu_var"); - ASSERT_EQ(absl::get(created_expr.accu_init()->expr_kind()).name(), + ASSERT_EQ(absl::get(created_expr.accu_init().expr_kind()).name(), "init"); - ASSERT_EQ(absl::get(created_expr.loop_condition()->expr_kind()).name(), + ASSERT_EQ(absl::get(created_expr.loop_condition().expr_kind()).name(), "cond"); - ASSERT_EQ(absl::get(created_expr.loop_step()->expr_kind()).name(), + ASSERT_EQ(absl::get(created_expr.loop_step().expr_kind()).name(), "step"); - ASSERT_EQ(absl::get(created_expr.result()->expr_kind()).name(), + ASSERT_EQ(absl::get(created_expr.result().expr_kind()).name(), "result"); } @@ -147,30 +202,51 @@ TEST(AstTest, ComprehensionMutableConstruction) { Comprehension comprehension; comprehension.mutable_iter_range().set_expr_kind(Ident("var")); ASSERT_TRUE( - absl::holds_alternative(comprehension.iter_range()->expr_kind())); - ASSERT_EQ(absl::get(comprehension.iter_range()->expr_kind()).name(), + absl::holds_alternative(comprehension.iter_range().expr_kind())); + ASSERT_EQ(absl::get(comprehension.iter_range().expr_kind()).name(), "var"); comprehension.mutable_accu_init().set_expr_kind(Ident("var")); ASSERT_TRUE( - absl::holds_alternative(comprehension.accu_init()->expr_kind())); - ASSERT_EQ(absl::get(comprehension.accu_init()->expr_kind()).name(), + absl::holds_alternative(comprehension.accu_init().expr_kind())); + ASSERT_EQ(absl::get(comprehension.accu_init().expr_kind()).name(), "var"); comprehension.mutable_loop_condition().set_expr_kind(Ident("var")); ASSERT_TRUE(absl::holds_alternative( - comprehension.loop_condition()->expr_kind())); - ASSERT_EQ( - absl::get(comprehension.loop_condition()->expr_kind()).name(), - "var"); + comprehension.loop_condition().expr_kind())); + ASSERT_EQ(absl::get(comprehension.loop_condition().expr_kind()).name(), + "var"); comprehension.mutable_loop_step().set_expr_kind(Ident("var")); ASSERT_TRUE( - absl::holds_alternative(comprehension.loop_step()->expr_kind())); - ASSERT_EQ(absl::get(comprehension.loop_step()->expr_kind()).name(), + absl::holds_alternative(comprehension.loop_step().expr_kind())); + ASSERT_EQ(absl::get(comprehension.loop_step().expr_kind()).name(), "var"); comprehension.mutable_result().set_expr_kind(Ident("var")); ASSERT_TRUE( - absl::holds_alternative(comprehension.result()->expr_kind())); - ASSERT_EQ(absl::get(comprehension.result()->expr_kind()).name(), - "var"); + absl::holds_alternative(comprehension.result().expr_kind())); + ASSERT_EQ(absl::get(comprehension.result().expr_kind()).name(), "var"); +} + +TEST(AstTest, ComprehensionDefaults) { + Comprehension comprehension; + EXPECT_TRUE(comprehension.iter_var().empty()); + EXPECT_EQ(comprehension.iter_range(), Expr()); + EXPECT_TRUE(comprehension.accu_var().empty()); + EXPECT_EQ(comprehension.accu_init(), Expr()); + EXPECT_EQ(comprehension.loop_condition(), Expr()); + EXPECT_EQ(comprehension.loop_step(), Expr()); + EXPECT_EQ(comprehension.result(), Expr()); +} + +TEST(AstTest, ComprehenesionComparatorIterVar) { + Comprehension comprehension; + comprehension.set_iter_var("var"); + EXPECT_FALSE(comprehension == Comprehension()); +} + +TEST(AstTest, ComprehenesionComparatorAccuVar) { + Comprehension comprehension; + comprehension.set_accu_var("var"); + EXPECT_FALSE(comprehension == Comprehension()); } TEST(AstTest, ExprMoveTest) { @@ -182,6 +258,17 @@ TEST(AstTest, ExprMoveTest) { ASSERT_EQ(absl::get(new_expr.expr_kind()).name(), "var"); } +TEST(AstTest, ExprDefaults) { + Expr expr; + EXPECT_EQ(expr.const_expr(), Constant()); + EXPECT_EQ(expr.ident_expr(), Ident()); + EXPECT_EQ(expr.select_expr(), Select()); + EXPECT_EQ(expr.call_expr(), Call()); + EXPECT_EQ(expr.list_expr(), CreateList()); + EXPECT_EQ(expr.struct_expr(), CreateStruct()); + EXPECT_EQ(expr.comprehension_expr(), Comprehension()); +} + TEST(AstTest, ParsedExpr) { ParsedExpr parsed_expr; parsed_expr.set_expr(Expr(1, Ident("name"))); @@ -204,7 +291,7 @@ TEST(AstTest, ParsedExpr) { TEST(AstTest, ListTypeMutableConstruction) { ListType type; type.mutable_elem_type() = Type(PrimitiveType::kBool); - EXPECT_EQ(absl::get(type.elem_type()->type_kind()), + EXPECT_EQ(absl::get(type.elem_type().type_kind()), PrimitiveType::kBool); } @@ -212,19 +299,37 @@ TEST(AstTest, MapTypeMutableConstruction) { MapType type; type.mutable_key_type() = Type(PrimitiveType::kBool); type.mutable_value_type() = Type(PrimitiveType::kBool); - EXPECT_EQ(absl::get(type.key_type()->type_kind()), + EXPECT_EQ(absl::get(type.key_type().type_kind()), PrimitiveType::kBool); - EXPECT_EQ(absl::get(type.value_type()->type_kind()), + EXPECT_EQ(absl::get(type.value_type().type_kind()), PrimitiveType::kBool); } +TEST(AstTest, MapTypeComparatorKeyType) { + MapType type; + type.mutable_key_type() = Type(PrimitiveType::kBool); + EXPECT_FALSE(type == MapType()); +} + +TEST(AstTest, MapTypeComparatorValueType) { + MapType type; + type.mutable_value_type() = Type(PrimitiveType::kBool); + EXPECT_FALSE(type == MapType()); +} + TEST(AstTest, FunctionTypeMutableConstruction) { FunctionType type; type.mutable_result_type() = Type(PrimitiveType::kBool); - EXPECT_EQ(absl::get(type.result_type()->type_kind()), + EXPECT_EQ(absl::get(type.result_type().type_kind()), PrimitiveType::kBool); } +TEST(AstTest, FunctionTypeComparatorArgTypes) { + FunctionType type; + type.mutable_arg_types().emplace_back(Type()); + EXPECT_FALSE(type == FunctionType()); +} + TEST(AstTest, CheckedExpr) { CheckedExpr checked_expr; checked_expr.set_expr(Expr(1, Ident("name"))); @@ -248,6 +353,38 @@ TEST(AstTest, CheckedExpr) { EXPECT_EQ(checked_expr.expr_version(), "expr_version"); } +TEST(AstTest, ListTypeDefaults) { EXPECT_EQ(ListType().elem_type(), Type()); } + +TEST(AstTest, MapTypeDefaults) { + EXPECT_EQ(MapType().key_type(), Type()); + EXPECT_EQ(MapType().value_type(), Type()); +} + +TEST(AstTest, FunctionTypeDefaults) { + EXPECT_EQ(FunctionType().result_type(), Type()); +} + +TEST(AstTest, TypeDefaults) { + EXPECT_EQ(Type().null(), NullValue::kNullValue); + EXPECT_EQ(Type().primitive(), PrimitiveType::kPrimitiveTypeUnspecified); + EXPECT_EQ(Type().wrapper(), PrimitiveType::kPrimitiveTypeUnspecified); + EXPECT_EQ(Type().well_known(), WellKnownType::kWellKnownTypeUnspecified); + EXPECT_EQ(Type().list_type(), ListType()); + EXPECT_EQ(Type().map_type(), MapType()); + EXPECT_EQ(Type().function(), FunctionType()); + EXPECT_EQ(Type().message_type(), MessageType()); + EXPECT_EQ(Type().type_param(), ParamType()); + EXPECT_EQ(Type().type(), Type()); + EXPECT_EQ(Type().error_type(), ErrorType()); + EXPECT_EQ(Type().abstract_type(), AbstractType()); +} + +TEST(AstTest, TypeComparatorTest) { + Type type; + type.set_type_kind(std::make_unique(PrimitiveType::kBool)); + EXPECT_FALSE(type.type() == Type()); +} + } // namespace } // namespace internal } // namespace ast diff --git a/base/ast_utility.cc b/base/ast_utility.cc index 812470d8b..a4cd54691 100644 --- a/base/ast_utility.cc +++ b/base/ast_utility.cc @@ -36,25 +36,26 @@ namespace cel::ast::internal { absl::StatusOr ToNative(const google::api::expr::v1alpha1::Constant& constant) { switch (constant.constant_kind_case()) { case google::api::expr::v1alpha1::Constant::kNullValue: - return NullValue::kNullValue; + return Constant(NullValue::kNullValue); case google::api::expr::v1alpha1::Constant::kBoolValue: - return constant.bool_value(); + return Constant(constant.bool_value()); case google::api::expr::v1alpha1::Constant::kInt64Value: - return constant.int64_value(); + return Constant(constant.int64_value()); case google::api::expr::v1alpha1::Constant::kUint64Value: - return constant.uint64_value(); + return Constant(constant.uint64_value()); case google::api::expr::v1alpha1::Constant::kDoubleValue: - return constant.double_value(); + return Constant(constant.double_value()); case google::api::expr::v1alpha1::Constant::kStringValue: - return constant.string_value(); + return Constant(constant.string_value()); case google::api::expr::v1alpha1::Constant::kBytesValue: - return constant.bytes_value(); + return Constant(constant.bytes_value()); case google::api::expr::v1alpha1::Constant::kDurationValue: - return absl::Seconds(constant.duration_value().seconds()) + - absl::Nanoseconds(constant.duration_value().nanos()); + return Constant(absl::Seconds(constant.duration_value().seconds()) + + absl::Nanoseconds(constant.duration_value().nanos())); case google::api::expr::v1alpha1::Constant::kTimestampValue: - return absl::FromUnixSeconds(constant.timestamp_value().seconds()) + - absl::Nanoseconds(constant.timestamp_value().nanos()); + return Constant( + absl::FromUnixSeconds(constant.timestamp_value().seconds()) + + absl::Nanoseconds(constant.timestamp_value().nanos())); default: return absl::InvalidArgumentError( "Illegal type supplied for google::api::expr::v1alpha1::Constant."); diff --git a/base/ast_utility_test.cc b/base/ast_utility_test.cc index 059d627ba..0a6e7f138 100644 --- a/base/ast_utility_test.cc +++ b/base/ast_utility_test.cc @@ -44,8 +44,8 @@ TEST(AstUtilityTest, IdentToNative) { auto native_expr = ToNative(expr); - ASSERT_TRUE(absl::holds_alternative(native_expr->expr_kind())); - EXPECT_EQ(absl::get(native_expr->expr_kind()).name(), "name"); + ASSERT_TRUE(native_expr->has_ident_expr()); + EXPECT_EQ(native_expr->ident_expr().name(), "name"); } TEST(AstUtilityTest, SelectToNative) { @@ -62,12 +62,10 @@ TEST(AstUtilityTest, SelectToNative) { auto native_expr = ToNative(expr); - ASSERT_TRUE(absl::holds_alternative(native_expr->expr_kind()); - ASSERT_TRUE( - absl::holds_alternative(native_select.operand()->expr_kind())); - EXPECT_EQ(absl::get(native_select.operand()->expr_kind()).name(), - "name"); + ASSERT_TRUE(native_expr->has_select_expr()); + auto& native_select = native_expr->select_expr(); + ASSERT_TRUE(native_select.operand().has_ident_expr()); + EXPECT_EQ(native_select.operand().ident_expr().name(), "name"); EXPECT_EQ(native_select.field(), "field"); EXPECT_TRUE(native_select.test_only()); } @@ -87,18 +85,17 @@ TEST(AstUtilityTest, CallToNative) { auto native_expr = ToNative(expr); - ASSERT_TRUE(absl::holds_alternative(native_expr->expr_kind())); - auto& native_call = absl::get(native_expr->expr_kind()); - ASSERT_TRUE( - absl::holds_alternative(native_call.target()->expr_kind())); - EXPECT_EQ(absl::get(native_call.target()->expr_kind()).name(), "name"); + ASSERT_TRUE(native_expr->has_call_expr()); + auto& native_call = native_expr->call_expr(); + ASSERT_TRUE(native_call.target().has_ident_expr()); + EXPECT_EQ(native_call.target().ident_expr().name(), "name"); EXPECT_EQ(native_call.function(), "function"); auto& native_arg1 = native_call.args()[0]; - ASSERT_TRUE(absl::holds_alternative(native_arg1.expr_kind())); - EXPECT_EQ(absl::get(native_arg1.expr_kind()).name(), "arg1"); + ASSERT_TRUE(native_arg1.has_ident_expr()); + EXPECT_EQ(native_arg1.ident_expr().name(), "arg1"); auto& native_arg2 = native_call.args()[1]; - ASSERT_TRUE(absl::holds_alternative(native_arg2.expr_kind())); - ASSERT_EQ(absl::get(native_arg2.expr_kind()).name(), "arg2"); + ASSERT_TRUE(native_arg2.has_ident_expr()); + ASSERT_EQ(native_arg2.ident_expr().name(), "arg2"); } TEST(AstUtilityTest, CreateListToNative) { @@ -114,14 +111,14 @@ TEST(AstUtilityTest, CreateListToNative) { auto native_expr = ToNative(expr); - ASSERT_TRUE(absl::holds_alternative(native_expr->expr_kind())); - auto& native_create_list = absl::get(native_expr->expr_kind()); + ASSERT_TRUE(native_expr->has_list_expr()); + auto& native_create_list = native_expr->list_expr(); auto& native_elem1 = native_create_list.elements()[0]; - ASSERT_TRUE(absl::holds_alternative(native_elem1.expr_kind())); - ASSERT_EQ(absl::get(native_elem1.expr_kind()).name(), "elem1"); + ASSERT_TRUE(native_elem1.has_ident_expr()); + ASSERT_EQ(native_elem1.ident_expr().name(), "elem1"); auto& native_elem2 = native_create_list.elements()[1]; - ASSERT_TRUE(absl::holds_alternative(native_elem2.expr_kind())); - ASSERT_EQ(absl::get(native_elem2.expr_kind()).name(), "elem2"); + ASSERT_TRUE(native_elem2.has_ident_expr()); + ASSERT_EQ(native_elem2.ident_expr().name(), "elem2"); } TEST(AstUtilityTest, CreateStructToNative) { @@ -145,29 +142,20 @@ TEST(AstUtilityTest, CreateStructToNative) { auto native_expr = ToNative(expr); - ASSERT_TRUE(absl::holds_alternative(native_expr->expr_kind())); - auto& native_struct = absl::get(native_expr->expr_kind()); + ASSERT_TRUE(native_expr->has_struct_expr()); + auto& native_struct = native_expr->struct_expr(); auto& native_entry1 = native_struct.entries()[0]; EXPECT_EQ(native_entry1.id(), 1); - ASSERT_TRUE(absl::holds_alternative(native_entry1.key_kind())); - ASSERT_EQ(absl::get(native_entry1.key_kind()), "key1"); - ASSERT_TRUE( - absl::holds_alternative(native_entry1.value()->expr_kind())); - ASSERT_EQ(absl::get(native_entry1.value()->expr_kind()).name(), - "value1"); + ASSERT_TRUE(native_entry1.has_field_key()); + ASSERT_EQ(native_entry1.field_key(), "key1"); + ASSERT_TRUE(native_entry1.value().has_ident_expr()); + ASSERT_EQ(native_entry1.value().ident_expr().name(), "value1"); auto& native_entry2 = native_struct.entries()[1]; EXPECT_EQ(native_entry2.id(), 2); - ASSERT_TRUE( - absl::holds_alternative>(native_entry2.key_kind())); - ASSERT_TRUE(absl::holds_alternative( - absl::get>(native_entry2.key_kind())->expr_kind())); - EXPECT_EQ(absl::get( - absl::get>(native_entry2.key_kind()) - ->expr_kind()) - .name(), - "key2"); - ASSERT_EQ(absl::get(native_entry2.value()->expr_kind()).name(), - "value2"); + ASSERT_TRUE(native_entry2.has_map_key()); + ASSERT_TRUE(native_entry2.map_key().has_ident_expr()); + EXPECT_EQ(native_entry2.map_key().ident_expr().name(), "key2"); + ASSERT_EQ(native_entry2.value().ident_expr().name(), "value2"); } TEST(AstUtilityTest, CreateStructError) { @@ -210,35 +198,22 @@ TEST(AstUtilityTest, ComprehensionToNative) { auto native_expr = ToNative(expr); - ASSERT_TRUE(absl::holds_alternative(native_expr->expr_kind())); - auto& native_comprehension = - absl::get(native_expr->expr_kind()); + ASSERT_TRUE(native_expr->has_comprehension_expr()); + auto& native_comprehension = native_expr->comprehension_expr(); EXPECT_EQ(native_comprehension.iter_var(), "iter_var"); - ASSERT_TRUE(absl::holds_alternative( - native_comprehension.iter_range()->expr_kind())); - EXPECT_EQ( - absl::get(native_comprehension.iter_range()->expr_kind()).name(), - "iter_range"); + ASSERT_TRUE(native_comprehension.iter_range().has_ident_expr()); + EXPECT_EQ(native_comprehension.iter_range().ident_expr().name(), + "iter_range"); EXPECT_EQ(native_comprehension.accu_var(), "accu_var"); - ASSERT_TRUE(absl::holds_alternative( - native_comprehension.accu_init()->expr_kind())); - EXPECT_EQ( - absl::get(native_comprehension.accu_init()->expr_kind()).name(), - "accu_init"); - ASSERT_TRUE(absl::holds_alternative( - native_comprehension.loop_condition()->expr_kind())); - EXPECT_EQ(absl::get(native_comprehension.loop_condition()->expr_kind()) - .name(), + ASSERT_TRUE(native_comprehension.accu_init().has_ident_expr()); + EXPECT_EQ(native_comprehension.accu_init().ident_expr().name(), "accu_init"); + ASSERT_TRUE(native_comprehension.loop_condition().has_ident_expr()); + EXPECT_EQ(native_comprehension.loop_condition().ident_expr().name(), "loop_condition"); - ASSERT_TRUE(absl::holds_alternative( - native_comprehension.loop_step()->expr_kind())); - EXPECT_EQ( - absl::get(native_comprehension.loop_step()->expr_kind()).name(), - "loop_step"); - ASSERT_TRUE(absl::holds_alternative( - native_comprehension.result()->expr_kind())); - EXPECT_EQ(absl::get(native_comprehension.result()->expr_kind()).name(), - "result"); + ASSERT_TRUE(native_comprehension.loop_step().has_ident_expr()); + EXPECT_EQ(native_comprehension.loop_step().ident_expr().name(), "loop_step"); + ASSERT_TRUE(native_comprehension.result().has_ident_expr()); + EXPECT_EQ(native_comprehension.result().ident_expr().name(), "result"); } TEST(AstUtilityTest, ConstantToNative) { @@ -248,10 +223,10 @@ TEST(AstUtilityTest, ConstantToNative) { auto native_expr = ToNative(expr); - ASSERT_TRUE(absl::holds_alternative(native_expr->expr_kind())); - auto& native_constant = absl::get(native_expr->expr_kind()); - ASSERT_TRUE(absl::holds_alternative(native_constant)); - EXPECT_EQ(absl::get(native_constant), NullValue::kNullValue); + ASSERT_TRUE(native_expr->has_const_expr()); + auto& native_constant = native_expr->const_expr(); + ASSERT_TRUE(native_constant.has_null_value()); + EXPECT_EQ(native_constant.null_value(), NullValue::kNullValue); } TEST(AstUtilityTest, ConstantBoolTrueToNative) { @@ -260,8 +235,8 @@ TEST(AstUtilityTest, ConstantBoolTrueToNative) { auto native_constant = ToNative(constant); - ASSERT_TRUE(absl::holds_alternative(*native_constant)); - EXPECT_TRUE(absl::get(*native_constant)); + ASSERT_TRUE(native_constant->has_bool_value()); + EXPECT_TRUE(native_constant->bool_value()); } TEST(AstUtilityTest, ConstantBoolFalseToNative) { @@ -270,8 +245,8 @@ TEST(AstUtilityTest, ConstantBoolFalseToNative) { auto native_constant = ToNative(constant); - ASSERT_TRUE(absl::holds_alternative(*native_constant)); - EXPECT_FALSE(absl::get(*native_constant)); + ASSERT_TRUE(native_constant->has_bool_value()); + EXPECT_FALSE(native_constant->bool_value()); } TEST(AstUtilityTest, ConstantInt64ToNative) { @@ -280,9 +255,9 @@ TEST(AstUtilityTest, ConstantInt64ToNative) { auto native_constant = ToNative(constant); - ASSERT_TRUE(absl::holds_alternative(*native_constant)); - ASSERT_FALSE(absl::holds_alternative(*native_constant)); - EXPECT_EQ(absl::get(*native_constant), -23); + ASSERT_TRUE(native_constant->has_int64_value()); + ASSERT_FALSE(native_constant->has_uint64_value()); + EXPECT_EQ(native_constant->int64_value(), -23); } TEST(AstUtilityTest, ConstantUint64ToNative) { @@ -291,9 +266,9 @@ TEST(AstUtilityTest, ConstantUint64ToNative) { auto native_constant = ToNative(constant); - ASSERT_TRUE(absl::holds_alternative(*native_constant)); - ASSERT_FALSE(absl::holds_alternative(*native_constant)); - EXPECT_EQ(absl::get(*native_constant), 23); + ASSERT_TRUE(native_constant->has_uint64_value()); + ASSERT_FALSE(native_constant->has_int64_value()); + EXPECT_EQ(native_constant->uint64_value(), 23); } TEST(AstUtilityTest, ConstantDoubleToNative) { @@ -302,8 +277,8 @@ TEST(AstUtilityTest, ConstantDoubleToNative) { auto native_constant = ToNative(constant); - ASSERT_TRUE(absl::holds_alternative(*native_constant)); - EXPECT_EQ(absl::get(*native_constant), 12.34); + ASSERT_TRUE(native_constant->has_double_value()); + EXPECT_EQ(native_constant->double_value(), 12.34); } TEST(AstUtilityTest, ConstantStringToNative) { @@ -312,8 +287,8 @@ TEST(AstUtilityTest, ConstantStringToNative) { auto native_constant = ToNative(constant); - ASSERT_TRUE(absl::holds_alternative(*native_constant)); - EXPECT_EQ(absl::get(*native_constant), "string"); + ASSERT_TRUE(native_constant->has_string_value()); + EXPECT_EQ(native_constant->string_value(), "string"); } TEST(AstUtilityTest, ConstantBytesToNative) { @@ -322,8 +297,8 @@ TEST(AstUtilityTest, ConstantBytesToNative) { auto native_constant = ToNative(constant); - ASSERT_TRUE(absl::holds_alternative(*native_constant)); - EXPECT_EQ(absl::get(*native_constant), "bytes"); + ASSERT_TRUE(native_constant->has_string_value()); + EXPECT_EQ(native_constant->string_value(), "bytes"); } TEST(AstUtilityTest, ConstantDurationToNative) { @@ -333,8 +308,8 @@ TEST(AstUtilityTest, ConstantDurationToNative) { auto native_constant = ToNative(constant); - ASSERT_TRUE(absl::holds_alternative(*native_constant)); - EXPECT_EQ(absl::get(*native_constant), + ASSERT_TRUE(native_constant->has_duration_value()); + EXPECT_EQ(native_constant->duration_value(), absl::Seconds(123) + absl::Nanoseconds(456)); } @@ -345,8 +320,8 @@ TEST(AstUtilityTest, ConstantTimestampToNative) { auto native_constant = ToNative(constant); - ASSERT_TRUE(absl::holds_alternative(*native_constant)); - EXPECT_EQ(absl::get(*native_constant), + ASSERT_TRUE(native_constant->has_time_value()); + EXPECT_EQ(native_constant->time_value(), absl::FromUnixSeconds(123) + absl::Nanoseconds(456)); } @@ -395,12 +370,9 @@ TEST(AstUtilityTest, SourceInfoToNative) { EXPECT_EQ(native_source_info->line_offsets(), std::vector({1, 2})); EXPECT_EQ(native_source_info->positions().at(1), 2); EXPECT_EQ(native_source_info->positions().at(3), 4); - ASSERT_TRUE(absl::holds_alternative( - native_source_info->macro_calls().at(1).expr_kind())); - ASSERT_EQ( - absl::get(native_source_info->macro_calls().at(1).expr_kind()) - .name(), - "name"); + ASSERT_TRUE(native_source_info->macro_calls().at(1).has_ident_expr()); + ASSERT_EQ(native_source_info->macro_calls().at(1).ident_expr().name(), + "name"); } TEST(AstUtilityTest, ParsedExprToNative) { @@ -425,21 +397,16 @@ TEST(AstUtilityTest, ParsedExprToNative) { auto native_parsed_expr = ToNative(parsed_expr); - ASSERT_TRUE( - absl::holds_alternative(native_parsed_expr->expr().expr_kind())); - ASSERT_EQ(absl::get(native_parsed_expr->expr().expr_kind()).name(), - "name"); + ASSERT_TRUE(native_parsed_expr->expr().has_ident_expr()); + ASSERT_EQ(native_parsed_expr->expr().ident_expr().name(), "name"); auto& native_source_info = native_parsed_expr->source_info(); EXPECT_EQ(native_source_info.syntax_version(), "version"); EXPECT_EQ(native_source_info.location(), "location"); EXPECT_EQ(native_source_info.line_offsets(), std::vector({1, 2})); EXPECT_EQ(native_source_info.positions().at(1), 2); EXPECT_EQ(native_source_info.positions().at(3), 4); - ASSERT_TRUE(absl::holds_alternative( - native_source_info.macro_calls().at(1).expr_kind())); - ASSERT_EQ(absl::get(native_source_info.macro_calls().at(1).expr_kind()) - .name(), - "name"); + ASSERT_TRUE(native_source_info.macro_calls().at(1).has_ident_expr()); + ASSERT_EQ(native_source_info.macro_calls().at(1).ident_expr().name(), "name"); } TEST(AstUtilityTest, PrimitiveTypeUnspecifiedToNative) { @@ -448,9 +415,8 @@ TEST(AstUtilityTest, PrimitiveTypeUnspecifiedToNative) { auto native_type = ToNative(type); - ASSERT_TRUE(absl::holds_alternative(native_type->type_kind())); - EXPECT_EQ(absl::get(native_type->type_kind()), - PrimitiveType::kPrimitiveTypeUnspecified); + ASSERT_TRUE(native_type->has_primitive()); + EXPECT_EQ(native_type->primitive(), PrimitiveType::kPrimitiveTypeUnspecified); } TEST(AstUtilityTest, PrimitiveTypeBoolToNative) { @@ -459,9 +425,8 @@ TEST(AstUtilityTest, PrimitiveTypeBoolToNative) { auto native_type = ToNative(type); - ASSERT_TRUE(absl::holds_alternative(native_type->type_kind())); - EXPECT_EQ(absl::get(native_type->type_kind()), - PrimitiveType::kBool); + ASSERT_TRUE(native_type->has_primitive()); + EXPECT_EQ(native_type->primitive(), PrimitiveType::kBool); } TEST(AstUtilityTest, PrimitiveTypeInt64ToNative) { @@ -470,9 +435,8 @@ TEST(AstUtilityTest, PrimitiveTypeInt64ToNative) { auto native_type = ToNative(type); - ASSERT_TRUE(absl::holds_alternative(native_type->type_kind())); - EXPECT_EQ(absl::get(native_type->type_kind()), - PrimitiveType::kInt64); + ASSERT_TRUE(native_type->has_primitive()); + EXPECT_EQ(native_type->primitive(), PrimitiveType::kInt64); } TEST(AstUtilityTest, PrimitiveTypeUint64ToNative) { @@ -481,9 +445,8 @@ TEST(AstUtilityTest, PrimitiveTypeUint64ToNative) { auto native_type = ToNative(type); - ASSERT_TRUE(absl::holds_alternative(native_type->type_kind())); - EXPECT_EQ(absl::get(native_type->type_kind()), - PrimitiveType::kUint64); + ASSERT_TRUE(native_type->has_primitive()); + EXPECT_EQ(native_type->primitive(), PrimitiveType::kUint64); } TEST(AstUtilityTest, PrimitiveTypeDoubleToNative) { @@ -492,9 +455,8 @@ TEST(AstUtilityTest, PrimitiveTypeDoubleToNative) { auto native_type = ToNative(type); - ASSERT_TRUE(absl::holds_alternative(native_type->type_kind())); - EXPECT_EQ(absl::get(native_type->type_kind()), - PrimitiveType::kDouble); + ASSERT_TRUE(native_type->has_primitive()); + EXPECT_EQ(native_type->primitive(), PrimitiveType::kDouble); } TEST(AstUtilityTest, PrimitiveTypeStringToNative) { @@ -503,9 +465,8 @@ TEST(AstUtilityTest, PrimitiveTypeStringToNative) { auto native_type = ToNative(type); - ASSERT_TRUE(absl::holds_alternative(native_type->type_kind())); - EXPECT_EQ(absl::get(native_type->type_kind()), - PrimitiveType::kString); + ASSERT_TRUE(native_type->has_primitive()); + EXPECT_EQ(native_type->primitive(), PrimitiveType::kString); } TEST(AstUtilityTest, PrimitiveTypeBytesToNative) { @@ -514,9 +475,8 @@ TEST(AstUtilityTest, PrimitiveTypeBytesToNative) { auto native_type = ToNative(type); - ASSERT_TRUE(absl::holds_alternative(native_type->type_kind())); - EXPECT_EQ(absl::get(native_type->type_kind()), - PrimitiveType::kBytes); + ASSERT_TRUE(native_type->has_primitive()); + EXPECT_EQ(native_type->primitive(), PrimitiveType::kBytes); } TEST(AstUtilityTest, PrimitiveTypeError) { @@ -537,8 +497,8 @@ TEST(AstUtilityTest, WellKnownTypeUnspecifiedToNative) { auto native_type = ToNative(type); - ASSERT_TRUE(absl::holds_alternative(native_type->type_kind())); - EXPECT_EQ(absl::get(native_type->type_kind()), + ASSERT_TRUE(native_type->has_well_known()); + EXPECT_EQ(native_type->well_known(), WellKnownType::kWellKnownTypeUnspecified); } @@ -548,9 +508,8 @@ TEST(AstUtilityTest, WellKnownTypeAnyToNative) { auto native_type = ToNative(type); - ASSERT_TRUE(absl::holds_alternative(native_type->type_kind())); - EXPECT_EQ(absl::get(native_type->type_kind()), - WellKnownType::kAny); + ASSERT_TRUE(native_type->has_well_known()); + EXPECT_EQ(native_type->well_known(), WellKnownType::kAny); } TEST(AstUtilityTest, WellKnownTypeTimestampToNative) { @@ -559,9 +518,8 @@ TEST(AstUtilityTest, WellKnownTypeTimestampToNative) { auto native_type = ToNative(type); - ASSERT_TRUE(absl::holds_alternative(native_type->type_kind())); - EXPECT_EQ(absl::get(native_type->type_kind()), - WellKnownType::kTimestamp); + ASSERT_TRUE(native_type->has_well_known()); + EXPECT_EQ(native_type->well_known(), WellKnownType::kTimestamp); } TEST(AstUtilityTest, WellKnownTypeDuraionToNative) { @@ -570,9 +528,8 @@ TEST(AstUtilityTest, WellKnownTypeDuraionToNative) { auto native_type = ToNative(type); - ASSERT_TRUE(absl::holds_alternative(native_type->type_kind())); - EXPECT_EQ(absl::get(native_type->type_kind()), - WellKnownType::kDuration); + ASSERT_TRUE(native_type->has_well_known()); + EXPECT_EQ(native_type->well_known(), WellKnownType::kDuration); } TEST(AstUtilityTest, WellKnownTypeError) { @@ -594,12 +551,10 @@ TEST(AstUtilityTest, ListTypeToNative) { auto native_type = ToNative(type); - ASSERT_TRUE(absl::holds_alternative(native_type->type_kind())); - auto& native_list_type = absl::get(native_type->type_kind()); - ASSERT_TRUE(absl::holds_alternative( - native_list_type.elem_type()->type_kind())); - EXPECT_EQ(absl::get(native_list_type.elem_type()->type_kind()), - PrimitiveType::kBool); + ASSERT_TRUE(native_type->has_list_type()); + auto& native_list_type = native_type->list_type(); + ASSERT_TRUE(native_list_type.elem_type().has_primitive()); + EXPECT_EQ(native_list_type.elem_type().primitive(), PrimitiveType::kBool); } TEST(AstUtilityTest, MapTypeToNative) { @@ -615,16 +570,12 @@ TEST(AstUtilityTest, MapTypeToNative) { auto native_type = ToNative(type); - ASSERT_TRUE(absl::holds_alternative(native_type->type_kind())); - auto& native_map_type = absl::get(native_type->type_kind()); - ASSERT_TRUE(absl::holds_alternative( - native_map_type.key_type()->type_kind())); - EXPECT_EQ(absl::get(native_map_type.key_type()->type_kind()), - PrimitiveType::kBool); - ASSERT_TRUE(absl::holds_alternative( - native_map_type.value_type()->type_kind())); - EXPECT_EQ(absl::get(native_map_type.value_type()->type_kind()), - PrimitiveType::kDouble); + ASSERT_TRUE(native_type->has_map_type()); + auto& native_map_type = native_type->map_type(); + ASSERT_TRUE(native_map_type.key_type().has_primitive()); + EXPECT_EQ(native_map_type.key_type().primitive(), PrimitiveType::kBool); + ASSERT_TRUE(native_map_type.value_type().has_primitive()); + EXPECT_EQ(native_map_type.value_type().primitive(), PrimitiveType::kDouble); } TEST(AstUtilityTest, FunctionTypeToNative) { @@ -641,23 +592,16 @@ TEST(AstUtilityTest, FunctionTypeToNative) { auto native_type = ToNative(type); - ASSERT_TRUE(absl::holds_alternative(native_type->type_kind())); - auto& native_function_type = - absl::get(native_type->type_kind()); - ASSERT_TRUE(absl::holds_alternative( - native_function_type.result_type()->type_kind())); - EXPECT_EQ( - absl::get(native_function_type.result_type()->type_kind()), - PrimitiveType::kBool); - ASSERT_TRUE(absl::holds_alternative( - native_function_type.arg_types().at(0).type_kind())); - EXPECT_EQ(absl::get( - native_function_type.arg_types().at(0).type_kind()), + ASSERT_TRUE(native_type->has_function()); + auto& native_function_type = native_type->function(); + ASSERT_TRUE(native_function_type.result_type().has_primitive()); + EXPECT_EQ(native_function_type.result_type().primitive(), + PrimitiveType::kBool); + ASSERT_TRUE(native_function_type.arg_types().at(0).has_primitive()); + EXPECT_EQ(native_function_type.arg_types().at(0).primitive(), PrimitiveType::kDouble); - ASSERT_TRUE(absl::holds_alternative( - native_function_type.arg_types().at(1).type_kind())); - EXPECT_EQ(absl::get( - native_function_type.arg_types().at(1).type_kind()), + ASSERT_TRUE(native_function_type.arg_types().at(1).has_primitive()); + EXPECT_EQ(native_function_type.arg_types().at(1).primitive(), PrimitiveType::kString); } @@ -675,19 +619,14 @@ TEST(AstUtilityTest, AbstractTypeToNative) { auto native_type = ToNative(type); - ASSERT_TRUE(absl::holds_alternative(native_type->type_kind())); - auto& native_abstract_type = - absl::get(native_type->type_kind()); + ASSERT_TRUE(native_type->has_abstract_type()); + auto& native_abstract_type = native_type->abstract_type(); EXPECT_EQ(native_abstract_type.name(), "name"); - ASSERT_TRUE(absl::holds_alternative( - native_abstract_type.parameter_types().at(0).type_kind())); - EXPECT_EQ(absl::get( - native_abstract_type.parameter_types().at(0).type_kind()), + ASSERT_TRUE(native_abstract_type.parameter_types().at(0).has_primitive()); + EXPECT_EQ(native_abstract_type.parameter_types().at(0).primitive(), PrimitiveType::kDouble); - ASSERT_TRUE(absl::holds_alternative( - native_abstract_type.parameter_types().at(1).type_kind())); - EXPECT_EQ(absl::get( - native_abstract_type.parameter_types().at(1).type_kind()), + ASSERT_TRUE(native_abstract_type.parameter_types().at(1).has_primitive()); + EXPECT_EQ(native_abstract_type.parameter_types().at(1).primitive(), PrimitiveType::kString); } @@ -697,7 +636,7 @@ TEST(AstUtilityTest, DynamicTypeToNative) { auto native_type = ToNative(type); - ASSERT_TRUE(absl::holds_alternative(native_type->type_kind())); + ASSERT_TRUE(native_type->has_dyn()); } TEST(AstUtilityTest, NullTypeToNative) { @@ -706,9 +645,8 @@ TEST(AstUtilityTest, NullTypeToNative) { auto native_type = ToNative(type); - ASSERT_TRUE(absl::holds_alternative(native_type->type_kind())); - EXPECT_EQ(absl::get(native_type->type_kind()), - NullValue::kNullValue); + ASSERT_TRUE(native_type->has_null()); + EXPECT_EQ(native_type->null(), NullValue::kNullValue); } TEST(AstUtilityTest, PrimitiveTypeWrapperToNative) { @@ -717,10 +655,8 @@ TEST(AstUtilityTest, PrimitiveTypeWrapperToNative) { auto native_type = ToNative(type); - ASSERT_TRUE( - absl::holds_alternative(native_type->type_kind())); - EXPECT_EQ(absl::get(native_type->type_kind()).type(), - PrimitiveType::kBool); + ASSERT_TRUE(native_type->has_wrapper()); + EXPECT_EQ(native_type->wrapper(), PrimitiveType::kBool); } TEST(AstUtilityTest, MessageTypeToNative) { @@ -729,8 +665,8 @@ TEST(AstUtilityTest, MessageTypeToNative) { auto native_type = ToNative(type); - ASSERT_TRUE(absl::holds_alternative(native_type->type_kind())); - EXPECT_EQ(absl::get(native_type->type_kind()).type(), "message"); + ASSERT_TRUE(native_type->has_message_type()); + EXPECT_EQ(native_type->message_type().type(), "message"); } TEST(AstUtilityTest, ParamTypeToNative) { @@ -739,8 +675,8 @@ TEST(AstUtilityTest, ParamTypeToNative) { auto native_type = ToNative(type); - ASSERT_TRUE(absl::holds_alternative(native_type->type_kind())); - EXPECT_EQ(absl::get(native_type->type_kind()).type(), "param"); + ASSERT_TRUE(native_type->has_type_param()); + EXPECT_EQ(native_type->type_param().type(), "param"); } TEST(AstUtilityTest, NestedTypeToNative) { @@ -749,10 +685,8 @@ TEST(AstUtilityTest, NestedTypeToNative) { auto native_type = ToNative(type); - ASSERT_TRUE( - absl::holds_alternative>(native_type->type_kind())); - EXPECT_TRUE(absl::holds_alternative( - absl::get>(native_type->type_kind())->type_kind())); + ASSERT_TRUE(native_type->has_type()); + EXPECT_TRUE(native_type->type().has_dyn()); } TEST(AstUtilityTest, TypeError) { @@ -780,7 +714,7 @@ TEST(AstUtilityTest, ReferenceToNative) { EXPECT_EQ(native_reference->name(), "name"); EXPECT_EQ(native_reference->overload_id(), std::vector({"id1", "id2"})); - EXPECT_TRUE(absl::get(native_reference->value())); + EXPECT_TRUE(native_reference->value().bool_value()); } TEST(AstUtilityTest, CheckedExprToNative) { @@ -822,24 +756,18 @@ TEST(AstUtilityTest, CheckedExprToNative) { EXPECT_EQ(native_checked_expr->reference_map().at(1).name(), "name"); EXPECT_EQ(native_checked_expr->reference_map().at(1).overload_id(), std::vector({"id1", "id2"})); - EXPECT_TRUE( - absl::get(native_checked_expr->reference_map().at(1).value())); + EXPECT_TRUE(native_checked_expr->reference_map().at(1).value().bool_value()); auto& native_source_info = native_checked_expr->source_info(); EXPECT_EQ(native_source_info.syntax_version(), "version"); EXPECT_EQ(native_source_info.location(), "location"); EXPECT_EQ(native_source_info.line_offsets(), std::vector({1, 2})); EXPECT_EQ(native_source_info.positions().at(1), 2); EXPECT_EQ(native_source_info.positions().at(3), 4); - ASSERT_TRUE(absl::holds_alternative( - native_source_info.macro_calls().at(1).expr_kind())); - ASSERT_EQ(absl::get(native_source_info.macro_calls().at(1).expr_kind()) - .name(), - "name"); + ASSERT_TRUE(native_source_info.macro_calls().at(1).has_ident_expr()); + ASSERT_EQ(native_source_info.macro_calls().at(1).ident_expr().name(), "name"); EXPECT_EQ(native_checked_expr->expr_version(), "version"); - ASSERT_TRUE( - absl::holds_alternative(native_checked_expr->expr().expr_kind())); - EXPECT_EQ(absl::get(native_checked_expr->expr().expr_kind()).name(), - "expr"); + ASSERT_TRUE(native_checked_expr->expr().has_ident_expr()); + EXPECT_EQ(native_checked_expr->expr().ident_expr().name(), "expr"); } } // namespace From b7f9d7280ca279b61844715c2b254f7a8b559e2c Mon Sep 17 00:00:00 2001 From: jcking Date: Thu, 16 Jun 2022 21:37:52 +0000 Subject: [PATCH 007/303] Internal change PiperOrigin-RevId: 455469729 --- base/BUILD | 25 +- base/handle.h | 125 ++++-- base/internal/BUILD | 51 ++- base/internal/data.h | 402 ++++++++++++++++++ base/internal/handle.h | 60 +++ base/internal/handle.post.h | 126 ------ base/internal/handle.pre.h | 173 -------- base/internal/managed_memory.cc | 77 ++++ base/internal/managed_memory.h | 373 ++++++++++++++++ ...{memory_manager.pre.h => memory_manager.h} | 36 +- base/internal/memory_manager.post.h | 50 --- base/internal/{type.pre.h => type.h} | 62 ++- base/internal/type.post.h | 190 --------- base/internal/value.h | 159 +++++++ base/internal/value.post.h | 314 -------------- base/internal/value.pre.h | 241 ----------- base/kind.h | 7 +- base/managed_memory.h | 36 ++ base/memory_manager.cc | 253 ++--------- base/memory_manager.h | 278 +++--------- base/type.cc | 288 ++++++++++++- base/type.h | 334 +++++++++++---- base/type_factory.cc | 77 +--- base/type_factory.h | 10 - base/types/any_type.cc | 21 +- base/types/any_type.h | 40 +- base/types/bool_type.cc | 21 +- base/types/bool_type.h | 41 +- base/types/bytes_type.cc | 21 +- base/types/bytes_type.h | 41 +- base/types/double_type.cc | 21 +- base/types/double_type.h | 41 +- base/types/duration_type.cc | 21 +- base/types/duration_type.h | 41 +- base/types/dyn_type.cc | 21 +- base/types/dyn_type.h | 40 +- base/types/enum_type.cc | 21 + base/types/enum_type.h | 31 +- base/types/error_type.cc | 21 +- base/types/error_type.h | 41 +- base/types/int_type.cc | 21 +- base/types/int_type.h | 41 +- base/types/list_type.cc | 16 +- base/types/list_type.h | 41 +- base/types/map_type.cc | 20 +- base/types/map_type.h | 45 +- base/types/null_type.cc | 21 +- base/types/null_type.h | 42 +- base/types/string_type.cc | 21 +- base/types/string_type.h | 41 +- base/types/struct_type.cc | 21 + base/types/struct_type.h | 30 +- base/types/timestamp_type.cc | 21 +- base/types/timestamp_type.h | 43 +- base/types/type_type.cc | 21 +- base/types/type_type.h | 42 +- base/types/uint_type.cc | 21 +- base/types/uint_type.h | 41 +- base/value.cc | 343 ++++++++++++++- base/value.h | 295 ++++++++++--- base/value_factory.cc | 83 +--- base/value_factory.h | 54 ++- base/value_test.cc | 246 +++++------ base/values/bool_value.cc | 32 +- base/values/bool_value.h | 49 +-- base/values/bytes_value.cc | 156 ++++--- base/values/bytes_value.h | 144 +++---- base/values/double_value.cc | 31 +- base/values/double_value.h | 50 +-- base/values/duration_value.cc | 31 +- base/values/duration_value.h | 49 +-- base/values/enum_value.cc | 22 +- base/values/enum_value.h | 85 ++-- base/values/error_value.cc | 25 +- base/values/error_value.h | 46 +- base/values/int_value.cc | 31 +- base/values/int_value.h | 45 +- base/values/list_value.cc | 14 +- base/values/list_value.h | 41 +- base/values/map_value.cc | 14 +- base/values/map_value.h | 41 +- base/values/null_value.cc | 36 +- base/values/null_value.h | 47 +- base/values/string_value.cc | 166 ++++---- base/values/string_value.h | 181 +++----- base/values/struct_value.cc | 11 + base/values/struct_value.h | 35 +- base/values/timestamp_value.cc | 31 +- base/values/timestamp_value.h | 53 +-- base/values/type_value.cc | 24 +- base/values/type_value.h | 49 ++- base/values/uint_value.cc | 31 +- base/values/uint_value.h | 45 +- extensions/protobuf/memory_manager.cc | 28 +- extensions/protobuf/memory_manager.h | 4 +- 95 files changed, 3851 insertions(+), 3597 deletions(-) create mode 100644 base/internal/data.h create mode 100644 base/internal/handle.h delete mode 100644 base/internal/handle.post.h delete mode 100644 base/internal/handle.pre.h create mode 100644 base/internal/managed_memory.cc create mode 100644 base/internal/managed_memory.h rename base/internal/{memory_manager.pre.h => memory_manager.h} (51%) delete mode 100644 base/internal/memory_manager.post.h rename base/internal/{type.pre.h => type.h} (51%) delete mode 100644 base/internal/type.post.h create mode 100644 base/internal/value.h delete mode 100644 base/internal/value.post.h delete mode 100644 base/internal/value.pre.h create mode 100644 base/managed_memory.h diff --git a/base/BUILD b/base/BUILD index f6b0157dd..33f2f8f1d 100644 --- a/base/BUILD +++ b/base/BUILD @@ -23,9 +23,11 @@ cc_library( name = "handle", hdrs = ["handle.h"], deps = [ + ":memory_manager", + "//base/internal:data", "//base/internal:handle", - "//internal:casts", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/utility", ], ) @@ -48,20 +50,26 @@ cc_test( ], ) +cc_library( + name = "managed_memory", + hdrs = ["managed_memory.h"], + deps = ["//base/internal:managed_memory"], +) + cc_library( name = "memory_manager", srcs = ["memory_manager.cc"], hdrs = ["memory_manager.h"], deps = [ + ":managed_memory", + "//base/internal:data", "//base/internal:memory_manager", "//internal:no_destructor", - "@com_google_absl//absl/base", "@com_google_absl//absl/base:config", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:dynamic_annotations", "@com_google_absl//absl/numeric:bits", "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/types:optional", ], ) @@ -111,14 +119,17 @@ cc_library( deps = [ ":handle", ":kind", - ":memory_manager", + "//base/internal:data", "//base/internal:type", - "//internal:no_destructor", + "//internal:casts", "//internal:rtti", + "@com_google_absl//absl/base", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/hash", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:variant", + "@com_google_absl//absl/utility", ], ) @@ -209,16 +220,16 @@ cc_library( deps = [ ":handle", ":kind", - ":memory_manager", ":type", + "//base/internal:data", "//base/internal:value", "//internal:casts", - "//internal:no_destructor", "//internal:rtti", "//internal:strings", "//internal:time", "//internal:utf8", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/hash", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", diff --git a/base/handle.h b/base/handle.h index 18124b908..460b2797a 100644 --- a/base/handle.h +++ b/base/handle.h @@ -20,13 +20,13 @@ #include "absl/base/attributes.h" #include "absl/base/macros.h" -#include "base/internal/handle.pre.h" // IWYU pragma: export -#include "internal/casts.h" +#include "absl/utility/utility.h" +#include "base/internal/data.h" +#include "base/internal/handle.h" // IWYU pragma: export +#include "base/memory_manager.h" namespace cel { -class MemoryManager; - template class Persistent; @@ -132,17 +132,17 @@ class Persistent final : private base_internal::HandlePolicy { // Is checks wether `T` is an instance of `F`. template bool Is() const { - return impl_.template Is(); + return static_cast(*this) && F::Is(static_cast(**this)); } T& operator*() const ABSL_ATTRIBUTE_LIFETIME_BOUND { ABSL_ASSERT(static_cast(*this)); - return internal::down_cast(*impl_); + return static_cast(*impl_.get()); } T* operator->() const ABSL_ATTRIBUTE_LIFETIME_BOUND { ABSL_ASSERT(static_cast(*this)); - return internal::down_cast(impl_.operator->()); + return static_cast(impl_.get()); } // Tests whether the handle is not empty, returning false if it is empty. @@ -152,8 +152,28 @@ class Persistent final : private base_internal::HandlePolicy { std::swap(lhs.impl_, rhs.impl_); } - friend bool operator==(const Persistent& lhs, const Persistent& rhs) { - return lhs.impl_ == rhs.impl_; + bool operator==(const Persistent& other) const { + return impl_ == other.impl_; + } + + template + std::enable_if_t, + std::is_convertible>, + bool> + operator==(const Persistent& other) const { + return impl_ == other.impl_; + } + + bool operator!=(const Persistent& other) const { + return !operator==(other); + } + + template + std::enable_if_t, + std::is_convertible>, + bool> + operator!=(const Persistent& other) const { + return !operator==(other); } template @@ -166,54 +186,67 @@ class Persistent final : private base_internal::HandlePolicy { friend class Persistent; template friend struct base_internal::HandleFactory; - template - friend bool base_internal::IsManagedHandle(const Persistent& handle); - template - friend bool base_internal::IsUnmanagedHandle(const Persistent& handle); - template - friend bool base_internal::IsInlinedHandle(const Persistent& handle); - template - friend MemoryManager& base_internal::GetMemoryManager( - const Persistent& handle); template - explicit Persistent(base_internal::HandleInPlace, Args&&... args) + explicit Persistent(absl::in_place_t in_place, Args&&... args) : impl_(std::forward(args)...) {} Handle impl_; }; -template -std::enable_if_t, bool> operator==( - const Persistent& lhs, const Persistent& rhs) { - return lhs == rhs.template As(); -} +} // namespace cel + +// ----------------------------------------------------------------------------- +// Internal implementation details. -template -std::enable_if_t, bool> operator==( - const Persistent& lhs, const Persistent& rhs) { - return rhs == lhs.template As(); -} +namespace cel::base_internal { template -bool operator!=(const Persistent& lhs, const Persistent& rhs) { - return !operator==(lhs, rhs); -} - -template -std::enable_if_t, bool> operator!=( - const Persistent& lhs, const Persistent& rhs) { - return !operator==(lhs, rhs); -} - -template -std::enable_if_t, bool> operator!=( - const Persistent& lhs, const Persistent& rhs) { - return !operator==(lhs, rhs); -} +struct HandleFactory { + // Constructs a persistent handle whose underlying object is stored in the + // handle itself. + template + static std::enable_if_t, Persistent> Make( + Args&&... args) { + static_assert(std::is_base_of_v, "T is not derived from Data"); + static_assert(std::is_base_of_v, "F is not derived from T"); + return Persistent(absl::in_place, absl::in_place_type, + std::forward(args)...); + } + // Constructs a persistent handle whose underlying object is stored in the + // handle itself. + template + static std::enable_if_t, void> MakeAt( + void* address, Args&&... args) { + static_assert(std::is_base_of_v, "T is not derived from Data"); + static_assert(std::is_base_of_v, "F is not derived from T"); + ::new (address) Persistent(absl::in_place, absl::in_place_type, + std::forward(args)...); + } -} // namespace cel + // Constructs a persistent handle whose underlying object is heap allocated + // and potentially reference counted, depending on the memory manager + // implementation. + template + static std::enable_if_t, Persistent> Make( + MemoryManager& memory_manager, Args&&... args) { + static_assert(std::is_base_of_v, "T is not derived from Data"); + static_assert(std::is_base_of_v, "F is not derived from T"); +#if defined(__cpp_lib_is_pointer_interconvertible) && \ + __cpp_lib_is_pointer_interconvertible >= 201907L + // Only available in C++20. + static_assert(std::is_pointer_interconvertible_base_of_v, + "F must be pointer interconvertible to Data"); +#endif + auto managed_memory = memory_manager.New(std::forward(args)...); + if (ABSL_PREDICT_FALSE(managed_memory == nullptr)) { + return Persistent(); + } + return Persistent(absl::in_place, + *base_internal::ManagedMemoryRelease(managed_memory)); + } +}; -#include "base/internal/handle.post.h" // IWYU pragma: export +} // namespace cel::base_internal #endif // THIRD_PARTY_CEL_CPP_BASE_HANDLE_H_ diff --git a/base/internal/BUILD b/base/internal/BUILD index 3b2cdd738..77759e1e9 100644 --- a/base/internal/BUILD +++ b/base/internal/BUILD @@ -16,25 +16,42 @@ package(default_visibility = ["//visibility:public"]) licenses(["notice"]) +cc_library( + name = "data", + hdrs = ["data.h"], + deps = [ + "//base:kind", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/numeric:bits", + ], +) + # These headers should only ever be used by ../handle.h. They are here to avoid putting # large amounts of implementation details in public headers. cc_library( name = "handle", - textual_hdrs = [ - "handle.pre.h", - "handle.post.h", + hdrs = ["handle.h"], + deps = [ + ":data", ], +) + +cc_library( + name = "managed_memory", + srcs = ["managed_memory.cc"], + hdrs = ["managed_memory.h"], deps = [ - "//base:memory_manager", + ":data", + "@com_google_absl//absl/base:config", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/numeric:bits", ], ) cc_library( name = "memory_manager", - textual_hdrs = [ - "memory_manager.pre.h", - "memory_manager.post.h", + hdrs = [ + "memory_manager.h", ], ) @@ -61,33 +78,31 @@ cc_library( cc_library( name = "type", textual_hdrs = [ - "type.pre.h", - "type.post.h", + "type.h", ], deps = [ + ":data", "//base:handle", + "//base:kind", + "//internal:casts", "//internal:rtti", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/numeric:bits", ], ) cc_library( name = "value", textual_hdrs = [ - "value.pre.h", - "value.post.h", + "value.h", ], deps = [ + ":data", "//base:handle", - "//internal:casts", + "//base:type", "//internal:rtti", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/numeric:bits", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/time", "@com_google_absl//absl/types:variant", ], ) diff --git a/base/internal/data.h b/base/internal/data.h new file mode 100644 index 000000000..73f53e0b7 --- /dev/null +++ b/base/internal/data.h @@ -0,0 +1,402 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_INTERNAL_DATA_H_ +#define THIRD_PARTY_CEL_CPP_BASE_INTERNAL_DATA_H_ + +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/macros.h" +#include "absl/numeric/bits.h" +#include "base/kind.h" + +namespace cel::base_internal { + +// Number of bits to shift to store kind. +inline constexpr int kKindShift = sizeof(uintptr_t) * 8 - 8; +// Mask that has all bits set except the most significant bit. +inline constexpr uint8_t kKindMask = (uint8_t{1} << 7) - 1; + +// uintptr_t with the least significant bit set. +inline constexpr uintptr_t kStoredInline = uintptr_t{1} << 0; +// uintptr_t with the second to least significant bit set. +inline constexpr uintptr_t kPointerArenaAllocated = uintptr_t{1} << 1; +// Mask that has all bits set except for `kPointerArenaAllocated`. +inline constexpr uintptr_t kPointerMask = ~kPointerArenaAllocated; +// uintptr_t with the most significant bit set. +inline constexpr uintptr_t kArenaAllocated = uintptr_t{1} + << (sizeof(uintptr_t) * 8 - 1); +inline constexpr uintptr_t kReferenceCounted = 1; +// uintptr_t with all bits set except for the most significant byte. +inline constexpr uintptr_t kReferenceCountMask = + kArenaAllocated | ((uintptr_t{1} << (sizeof(uintptr_t) * 8 - 8)) - 1); +inline constexpr uintptr_t kReferenceCountMax = + ((uintptr_t{1} << (sizeof(uintptr_t) * 8 - 8)) - 1); + +// uintptr_t with the 8th bit set. Used by inline data to indicate it is +// trivially copyable. +inline constexpr uintptr_t kTriviallyCopyable = 1 << 8; +// uintptr_t with the 9th bit set. Used by inline data to indicate it is +// trivially destuctible. +inline constexpr uintptr_t kTriviallyDestructible = 1 << 9; + +// We assert some expectations we have around alignment, size, and trivial +// destructability. +static_assert(sizeof(uintptr_t) == sizeof(std::atomic), + "uintptr_t and std::atomic must have the same size"); +static_assert(sizeof(void*) == sizeof(uintptr_t), + "void* and uintptr_t must have the same size"); +static_assert(std::is_trivially_destructible_v>, + "std::atomic must be trivially destructible"); + +enum class DataLocality { + kNull = 0, + kStoredInline, + kReferenceCounted, + kArenaAllocated, +}; + +// Empty base class of all classes that can be managed by handles. +// +// All `Data` implementations have a size of at least `sizeof(uintptr_t)`, have +// a `uintptr_t` at offset 0, and have an alignment that is at most +// `alignof(std::max_align_t)`. +// +// `Data` implementations are split into two categories: those stored inline and +// those allocated separately on the heap. This detail is not exposed to users +// and is managed entirely by the handles. We use a novel approach where given a +// pointer to some instantiated Data we can determine whether it is stored in a +// handle or allocated separately on the heap. If it is allocated on the heap we +// can then determine if it was allocated in an arena or if it is reference +// counted. We can also determine the `Kind` of data. +// +// We can determine whether data is stored directly in a handle by reading a +// `uintptr_t` at offset 0. If the least significant bit is set, this data is +// stored inside a handle. We rely on the fact that C++ places the virtual +// pointer to the virtual function table at offset 0 and it should be aligned to +// at least `sizeof(void*)`. +class Data {}; + +// Empty base class indicating class must be stored directly in the handle and +// not allocated separately on the heap. +// +// For inline data, Kind is stored in the most significant byte of `metadata`. +class InlineData /* : public Data */ { + // uintptr_t metadata + + public: + static void* operator new(size_t) = delete; + static void* operator new[](size_t) = delete; + + static void operator delete(void*) = delete; + static void operator delete[](void*) = delete; +}; + +// Used purely for a static_assert. +constexpr size_t HeapDataMetadataAndReferenceCountOffset(); + +// Base class indicating class must be allocated on the heap and not stored +// directly in a handle. +// +// For heap data, Kind is stored in the most significant byte of +// `metadata_and_reference_count`. If heap data was arena allocated, the most +// significant bit of the most significant byte is set. This property, combined +// with twos complement integers, allows us to easily detect incorrect reference +// counting as the reference count will be negative. +class HeapData /* : public Data */ { + // uintptr_t vptr + // std::atomic metadata_and_reference_count + + public: + HeapData(const HeapData&) = delete; + HeapData(HeapData&&) = delete; + + virtual ~HeapData() = default; + + HeapData& operator=(const HeapData&) = delete; + HeapData& operator=(HeapData&&) = delete; + + protected: + explicit HeapData(Kind kind) + : metadata_and_reference_count_(static_cast(kind) + << kKindShift) {} + + private: + friend constexpr size_t HeapDataMetadataAndReferenceCountOffset(); + + std::atomic metadata_and_reference_count_ ABSL_ATTRIBUTE_UNUSED = + 0; +}; + +inline constexpr size_t HeapDataMetadataAndReferenceCountOffset() { + return offsetof(HeapData, metadata_and_reference_count_); +} + +static_assert(HeapDataMetadataAndReferenceCountOffset() == sizeof(uintptr_t), + "Expected vptr to be at offset 0"); +static_assert(sizeof(HeapData) == sizeof(uintptr_t) * 2, + "Unexpected class size"); + +// Provides introspection for `Data`. +class Metadata final { + public: + ABSL_ATTRIBUTE_ALWAYS_INLINE static Metadata* For(Data* data) { + ABSL_ASSERT(data != nullptr); + return reinterpret_cast(data); + } + + ABSL_ATTRIBUTE_ALWAYS_INLINE static const Metadata* For(const Data* data) { + ABSL_ASSERT(data != nullptr); + return reinterpret_cast(data); + } + + Kind kind() const { + ABSL_ASSERT(!IsNull()); + return static_cast( + ((IsStoredInline() + ? *reinterpret_cast(this) + : reference_count()->load(std::memory_order_relaxed)) >> + kKindShift) & + kKindMask); + } + + DataLocality locality() const { + // We specifically do not use `IsArenaAllocated()` and + // `IsReferenceCounted()` here due to performance reasons. This code is + // called often in handle implementations. + return IsNull() ? DataLocality::kNull + : IsStoredInline() ? DataLocality::kStoredInline + : ((reference_count()->load(std::memory_order_relaxed) & + kArenaAllocated) != kArenaAllocated) + ? DataLocality::kReferenceCounted + : DataLocality::kArenaAllocated; + } + + ABSL_ATTRIBUTE_ALWAYS_INLINE bool IsNull() const { + return *reinterpret_cast(this) == 0; + } + + ABSL_ATTRIBUTE_ALWAYS_INLINE bool IsStoredInline() const { + return (*reinterpret_cast(this) & kStoredInline) == + kStoredInline; + } + + bool IsArenaAllocated() const { + return !IsNull() && !IsStoredInline() && + // We use relaxed because the top 8 bits are never mutated during + // reference counting and that is all we care about. + (reference_count()->load(std::memory_order_relaxed) & + kArenaAllocated) == kArenaAllocated; + } + + bool IsReferenceCounted() const { + return !IsNull() && !IsStoredInline() && + // We use relaxed because the top 8 bits are never mutated during + // reference counting and that is all we care about. + (reference_count()->load(std::memory_order_relaxed) & + kArenaAllocated) != kArenaAllocated; + } + + void Ref() const { + ABSL_ASSERT(IsReferenceCounted()); + const auto count = + (reference_count()->fetch_add(1, std::memory_order_relaxed)) & + kReferenceCountMask; + ABSL_ASSERT(count > 0 && count < kReferenceCountMax); + } + + bool Unref() const { + ABSL_ASSERT(IsReferenceCounted()); + const auto count = + (reference_count()->fetch_sub(1, std::memory_order_seq_cst)) & + kReferenceCountMask; + ABSL_ASSERT(count > 0 && count < kReferenceCountMax); + return count == 1; + } + + bool IsUnique() const { + ABSL_ASSERT(IsReferenceCounted()); + return ((reference_count()->fetch_add(1, std::memory_order_acquire)) & + kReferenceCountMask) == 1; + } + + bool IsTriviallyCopyable() const { + ABSL_ASSERT(IsStoredInline()); + return (*reinterpret_cast(this) & kTriviallyCopyable) == + kTriviallyCopyable; + } + + bool IsTriviallyDestructible() const { + ABSL_ASSERT(IsStoredInline()); + return (*reinterpret_cast(this) & + kTriviallyDestructible) == kTriviallyDestructible; + } + + // Used by `MemoryManager::New()`. + void SetArenaAllocated() { + reference_count()->fetch_or(kArenaAllocated, std::memory_order_relaxed); + } + + // Used by `MemoryManager::New()`. + void SetReferenceCounted() { + reference_count()->fetch_or(kReferenceCounted, std::memory_order_relaxed); + } + + private: + std::atomic* reference_count() const { + return reinterpret_cast*>( + const_cast(reinterpret_cast(this) + 1)); + } + + Metadata() = delete; +}; + +template +union alignas(Align) AnyDataStorage final { + AnyDataStorage() : pointer(0) {} + + uintptr_t pointer; + char buffer[Size]; +}; + +// Struct capable of storing data directly or a pointer to data. This is used by +// handle implementations. We use an additional bit to determine whether the +// data pointed to is arena allocated. During arena deletion, we cannot +// dereference our stored pointers as it may have already been deleted. Thus we +// need to know if it was arena allocated without dereferencing the pointer. +template +struct AnyData { + static_assert(Size >= sizeof(uintptr_t), + "Size must be at least sizeof(uintptr_t)"); + static_assert(Align >= alignof(uintptr_t), + "Align must be at least alignof(uintptr_t)"); + + using Storage = AnyDataStorage; + + Kind kind() const { + ABSL_ASSERT(!IsNull()); + return Metadata::For(get())->kind(); + } + + DataLocality locality() const { + return storage.pointer == 0 ? DataLocality::kNull + : (storage.pointer & kStoredInline) == kStoredInline + ? DataLocality::kStoredInline + : (storage.pointer & kPointerArenaAllocated) == + kPointerArenaAllocated + ? DataLocality::kArenaAllocated + : DataLocality::kReferenceCounted; + } + + bool IsNull() const { return storage.pointer == 0; } + + bool IsStoredInline() const { + return (storage.pointer & kStoredInline) == kStoredInline; + } + + bool IsArenaAllocated() const { + return (storage.pointer & kPointerArenaAllocated) == kPointerArenaAllocated; + } + + bool IsReferenceCounted() const { + return storage.pointer != 0 && + (storage.pointer & (kStoredInline | kPointerArenaAllocated)) == 0; + } + + void Ref() const { + ABSL_ASSERT(IsReferenceCounted()); + Metadata::For(get())->Ref(); + } + + bool Unref() const { + ABSL_ASSERT(IsReferenceCounted()); + return Metadata::For(get())->Unref(); + } + + bool IsUnique() const { + ABSL_ASSERT(IsReferenceCounted()); + return Metadata::For(get())->IsUnique(); + } + + bool IsTriviallyCopyable() const { + ABSL_ASSERT(IsStoredInline()); + return Metadata::For(get())->IsTriviallyCopyable(); + } + + bool IsTriviallyDestructible() const { + ABSL_ASSERT(IsStoredInline()); + return Metadata::For(get())->IsTriviallyDestructible(); + } + + // IMPORTANT: Do not use `Metadata::For(get())` unless you know what you are + // doing, instead us the method of the same name in this class. + ABSL_ATTRIBUTE_ALWAYS_INLINE Data* get() const { + return (storage.pointer & kStoredInline) == kStoredInline + ? reinterpret_cast( + const_cast(&storage.pointer)) + : reinterpret_cast(storage.pointer & kPointerMask); + } + + // Copy the bytes from other, similar to `std::memcpy`. + void CopyFrom(const AnyData& other) { + std::memcpy(&storage.buffer[0], &other.storage.buffer[0], Size); + } + + // Move the bytes from other, similar to `std::memcpy` and `std::memset`. + void MoveFrom(AnyData& other) { + std::memcpy(&storage.buffer[0], &other.storage.buffer[0], Size); + other.Clear(); + } + + template + void Destruct() { + ABSL_ASSERT(IsStoredInline()); + static_cast(get())->~T(); + } + + void Clear() { + // We only need to clear the first `sizeof(uintptr_t)` bytes as that is + // consulted to determine locality. + storage.pointer = 0; + } + + // Counterpart to `Metadata::SetArenaAllocated()` and + // `Metadata::SetReferenceCounted()`, also used by `MemoryManager`. + void ConstructHeap(const Data& data) { + ABSL_ASSERT(absl::countr_zero(reinterpret_cast(&data)) >= + 2); // Assert pointer alignment results in at least the 2 least + // significant bits being unset. + storage.pointer = + reinterpret_cast(&data) | + (Metadata::For(&data)->IsArenaAllocated() ? kPointerArenaAllocated : 0); + } + + template + void ConstructInline(Args&&... args) { + ::new (&storage.buffer[0]) T(std::forward(args)...); + ABSL_ASSERT(absl::countr_zero(storage.pointer) == + 0); // Assert the least significant bit is set. + } + + Storage storage; +}; + +} // namespace cel::base_internal + +#endif // THIRD_PARTY_CEL_CPP_BASE_INTERNAL_DATA_H_ diff --git a/base/internal/handle.h b/base/internal/handle.h new file mode 100644 index 000000000..ff7b16ac0 --- /dev/null +++ b/base/internal/handle.h @@ -0,0 +1,60 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "base/handle.h" + +#ifndef THIRD_PARTY_CEL_CPP_BASE_INTERNAL_HANDLE_PRE_H_ +#define THIRD_PARTY_CEL_CPP_BASE_INTERNAL_HANDLE_PRE_H_ + +#include + +#include "base/internal/data.h" + +namespace cel::base_internal { + +// Enumeration of different types of handles. +enum class HandleType { + kPersistent = 0, +}; + +template +struct HandleTraits; + +// Convenient aliases. +template +using PersistentHandleTraits = HandleTraits; + +template +struct HandleFactory; + +// Convenient aliases. +template +using PersistentHandleFactory = HandleFactory; + +// Non-virtual base class enforces type requirements via static_asserts for +// types used with handles. +template +struct HandlePolicy { + static_assert(!std::is_reference_v, "Handles do not support references"); + static_assert(!std::is_pointer_v, "Handles do not support pointers"); + static_assert(std::is_class_v, "Handles only support classes"); + static_assert(!std::is_volatile_v, "Handles do not support volatile"); + static_assert((std::is_base_of_v> && + !std::is_same_v>), + "Handles do not support this type"); +}; + +} // namespace cel::base_internal + +#endif // THIRD_PARTY_CEL_CPP_BASE_INTERNAL_HANDLE_PRE_H_ diff --git a/base/internal/handle.post.h b/base/internal/handle.post.h deleted file mode 100644 index 8f02c0766..000000000 --- a/base/internal/handle.post.h +++ /dev/null @@ -1,126 +0,0 @@ -// Copyright 2022 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// IWYU pragma: private, include "base/handle.h" - -#ifndef THIRD_PARTY_CEL_CPP_BASE_INTERNAL_HANDLE_POST_H_ -#define THIRD_PARTY_CEL_CPP_BASE_INTERNAL_HANDLE_POST_H_ - -#include -#include - -#include "absl/base/optimization.h" -#include "base/memory_manager.h" - -namespace cel::base_internal { - -template -struct HandleFactory { - // Constructs a persistent handle whose underlying object is stored in the - // handle itself. - template - static std::enable_if_t, Persistent> - Make(Args&&... args) { - static_assert(std::is_base_of_v, - "T is not derived from Resource"); - static_assert(std::is_base_of_v, "F is not derived from T"); - return Persistent(kHandleInPlace, kInlinedResource, - std::forward(args)...); - } - - // Constructs a persistent handle whose underlying object is heap allocated - // and potentially reference counted, depending on the memory manager - // implementation. - template - static std::enable_if_t>, - Persistent> - Make(MemoryManager& memory_manager, Args&&... args) { - static_assert(std::is_base_of_v, - "T is not derived from Resource"); - static_assert(std::is_base_of_v, "F is not derived from T"); -#if defined(__cpp_lib_is_pointer_interconvertible) && \ - __cpp_lib_is_pointer_interconvertible >= 201907L - // Only available in C++20. - static_assert(std::is_pointer_interconvertible_base_of_v, - "F must be pointer interconvertible to Resource"); -#endif - auto managed_memory = memory_manager.New(std::forward(args)...); - if (ABSL_PREDICT_FALSE(managed_memory == nullptr)) { - return Persistent(); - } - bool unmanaged = GetManagedMemorySize(managed_memory) == 0; -#ifndef NDEBUG - if (!unmanaged) { - // Ensure there is no funny business going on by asserting that the size - // and alignment are the same as F. - ABSL_ASSERT(GetManagedMemorySize(managed_memory) == sizeof(F)); - ABSL_ASSERT(GetManagedMemoryAlignment(managed_memory) == alignof(F)); - // Ensure that the implementation F has correctly overriden - // SizeAndAlignment(). - auto [size, align] = static_cast(managed_memory.get()) - ->SizeAndAlignment(); - ABSL_ASSERT(size == sizeof(F)); - ABSL_ASSERT(align == alignof(F)); - // Ensures that casting F to the most base class does not require - // thunking, which occurs when using multiple inheritance. If thunking is - // used our usage of memory manager will break. If you think you need - // thunking, please consult the CEL team. - ABSL_ASSERT(static_cast( - static_cast(managed_memory.get())) == - static_cast(managed_memory.get())); - } -#endif - // Convert ManagedMemory to Persistent, transferring reference - // counting responsibility to it when applicable. `unmanaged` is true when - // no reference counting is required. - return unmanaged ? Persistent(kHandleInPlace, kUnmanagedResource, - *ManagedMemoryRelease(managed_memory)) - : Persistent(kHandleInPlace, kManagedResource, - *ManagedMemoryRelease(managed_memory)); - } - - template - static Persistent MakeUnmanaged(F& from) { - static_assert(std::is_base_of_v, "F is not derived from T"); - return Persistent(kHandleInPlace, kUnmanagedResource, from); - } -}; - -template -bool IsManagedHandle(const Persistent& handle) { - return handle.impl_.IsManaged(); -} - -template -bool IsUnmanagedHandle(const Persistent& handle) { - return handle.impl_.IsUnmanaged(); -} - -template -bool IsInlinedHandle(const Persistent& handle) { - return handle.impl_.IsInlined(); -} - -template -MemoryManager& GetMemoryManager(const Persistent& handle) { - ABSL_ASSERT(IsManagedHandle(handle)); - auto [size, align] = - static_cast(handle.operator->())->SizeAndAlignment(); - return GetMemoryManager(static_cast(handle.operator->()), size, - align); -} - -} // namespace cel::base_internal - -#endif // THIRD_PARTY_CEL_CPP_BASE_INTERNAL_HANDLE_POST_H_ diff --git a/base/internal/handle.pre.h b/base/internal/handle.pre.h deleted file mode 100644 index af0e5fdb0..000000000 --- a/base/internal/handle.pre.h +++ /dev/null @@ -1,173 +0,0 @@ -// Copyright 2022 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// IWYU pragma: private, include "base/handle.h" - -#ifndef THIRD_PARTY_CEL_CPP_BASE_INTERNAL_HANDLE_PRE_H_ -#define THIRD_PARTY_CEL_CPP_BASE_INTERNAL_HANDLE_PRE_H_ - -#include -#include -#include - -#include "base/memory_manager.h" - -namespace cel { - -class Type; -class Value; - -template -class Persistent; - -class MemoryManager; - -namespace base_internal { - -class TypeHandleBase; -class ValueHandleBase; - -// Enumeration of different types of handles. -enum class HandleType { - kPersistent = 0, -}; - -template -struct HandleTraits; - -// Convenient aliases. -template -using PersistentHandleTraits = HandleTraits; - -template -struct HandleFactory; - -// Convenient aliases. -template -using PersistentHandleFactory = HandleFactory; - -struct HandleInPlace { - explicit HandleInPlace() = default; -}; - -// Disambiguation tag used to select the appropriate constructor on Persistent -// and Transient. Think std::in_place. -inline constexpr HandleInPlace kHandleInPlace{}; - -// If IsManagedHandle returns true, get a reference to the memory manager that -// is managing it. -template -MemoryManager& GetMemoryManager(const Persistent& handle); - -// Virtual base class for all classes that can be managed by handles. -class Resource { - public: - virtual ~Resource() = default; - - Resource& operator=(const Resource&) = delete; - Resource& operator=(Resource&&) = delete; - - private: - friend class cel::Type; - friend class cel::Value; - friend class TypeHandleBase; - friend class ValueHandleBase; - template - friend struct HandleFactory; - template - friend MemoryManager& GetMemoryManager(const Persistent& handle); - - Resource() = default; - Resource(const Resource&) = default; - Resource(Resource&&) = default; - - // For non-inlined resources that are reference counted, this is the result of - // `sizeof` and `alignof` for the most derived class. - virtual std::pair SizeAndAlignment() const = 0; - - // Called by TypeHandleBase, ValueHandleBase, Type, and Value for reference - // counting. - void Ref() const { - auto [size, align] = SizeAndAlignment(); - MemoryManager::Ref(this, size, align); - } - - // Called by TypeHandleBase, ValueHandleBase, Type, and Value for reference - // counting. - void Unref() const { - auto [size, align] = SizeAndAlignment(); - MemoryManager::Unref(this, size, align); - } -}; - -// Non-virtual base class for all classes that can be stored inline in handles. -// This is primarily used with SFINAE. -class ResourceInlined {}; - -template -struct InlinedResource { - explicit InlinedResource() = default; -}; - -// Disambiguation tag used to select the appropriate constructor in the handle -// implementation. Think std::in_place. -template -inline constexpr InlinedResource kInlinedResource{}; - -template -struct ManagedResource { - explicit ManagedResource() = default; -}; - -// Disambiguation tag used to select the appropriate constructor in the handle -// implementation. Think std::in_place. -template -inline constexpr ManagedResource kManagedResource{}; - -template -struct UnmanagedResource { - explicit UnmanagedResource() = default; -}; - -// Disambiguation tag used to select the appropriate constructor in the handle -// implementation. Think std::in_place. -template -inline constexpr UnmanagedResource kUnmanagedResource{}; - -// Non-virtual base class enforces type requirements via static_asserts for -// types used with handles. -template -struct HandlePolicy { - static_assert(!std::is_reference_v, "Handles do not support references"); - static_assert(!std::is_pointer_v, "Handles do not support pointers"); - static_assert(std::is_class_v, "Handles only support classes"); - static_assert(!std::is_volatile_v, "Handles do not support volatile"); - static_assert((std::is_base_of_v> && - !std::is_same_v> && - !std::is_same_v>), - "Handles do not support this type"); -}; - -template -bool IsManagedHandle(const Persistent& handle); -template -bool IsUnmanagedHandle(const Persistent& handle); -template -bool IsInlinedHandle(const Persistent& handle); - -} // namespace base_internal - -} // namespace cel - -#endif // THIRD_PARTY_CEL_CPP_BASE_INTERNAL_HANDLE_PRE_H_ diff --git a/base/internal/managed_memory.cc b/base/internal/managed_memory.cc new file mode 100644 index 000000000..ce13dc587 --- /dev/null +++ b/base/internal/managed_memory.cc @@ -0,0 +1,77 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/internal/managed_memory.h" + +#include +#include +#include + +#include "absl/base/config.h" +#include "absl/base/macros.h" +#include "absl/base/optimization.h" +#include "absl/numeric/bits.h" + +namespace cel::base_internal { + +namespace { + +size_t AlignUp(size_t size, size_t align) { +#if ABSL_HAVE_BUILTIN(__builtin_align_up) + return __builtin_align_up(size, align); +#else + return (size + align - size_t{1}) & ~(align - size_t{1}); +#endif +} + +} // namespace + +std::pair ManagedMemoryState::New( + size_t size, size_t align, ManagedMemoryDestructor destructor) { + ABSL_ASSERT(size != 0); + ABSL_ASSERT(absl::has_single_bit(align)); // Assert aligned to power of 2. + if (ABSL_PREDICT_TRUE(align <= sizeof(ManagedMemoryState))) { + // Alignment requirements are less than the size of `ManagedMemoryState`, we + // can place `ManagedMemoryState` in front. + uint8_t* pointer = reinterpret_cast( + ::operator new(size + sizeof(ManagedMemoryState))); + ::new (pointer) ManagedMemoryState(destructor); + return {reinterpret_cast(pointer), + static_cast(pointer + sizeof(ManagedMemoryState))}; + } + // Alignment requirements are greater than the size of `ManagedMemoryState`, + // we need to place `ManagedMemoryState` at the back and pad to ensure + // `ManagedMemoryState` itself is aligned. + size_t adjusted_size = AlignUp(size, alignof(ManagedMemoryState)); + uint8_t* pointer = reinterpret_cast( + ::operator new(adjusted_size + sizeof(ManagedMemoryState))); + ::new (pointer + adjusted_size) ManagedMemoryState(destructor); + return {reinterpret_cast(pointer + adjusted_size), + static_cast(pointer)}; +} + +void ManagedMemoryState::Delete(void* pointer) { + ABSL_ASSERT(pointer != nullptr); + ABSL_ASSERT(this != pointer); + if (destructor_ != nullptr) { + (*destructor_)(pointer); + } + this->~ManagedMemoryState(); + ::operator delete(reinterpret_cast(this) < + static_cast(pointer) + ? static_cast(this) + : const_cast(pointer)); +} + +} // namespace cel::base_internal diff --git a/base/internal/managed_memory.h b/base/internal/managed_memory.h new file mode 100644 index 000000000..366a2b014 --- /dev/null +++ b/base/internal/managed_memory.h @@ -0,0 +1,373 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_INTERNAL_MANAGED_MEMORY_H_ +#define THIRD_PARTY_CEL_CPP_BASE_INTERNAL_MANAGED_MEMORY_H_ + +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/macros.h" +#include "absl/numeric/bits.h" +#include "base/internal/data.h" + +namespace cel { + +class MemoryManager; + +namespace base_internal { + +template > +class ManagedMemory; + +template +T* ManagedMemoryRelease(ManagedMemory& managed_memory); + +// ManagedMemory implementation for T that is derived from Data and HeapData. +template +class ManagedMemory final { + private: + static_assert(std::is_base_of_v, + "T must be derived from HeapData"); + + public: + ManagedMemory() = default; + + explicit ManagedMemory(std::nullptr_t) : ManagedMemory() {} + + ManagedMemory(const ManagedMemory& other) : pointer_(other.pointer_) { + Ref(); + } + + template >> + ManagedMemory(const ManagedMemory& other) // NOLINT + : pointer_(other.pointer_) { + Ref(); + } + + ManagedMemory(ManagedMemory&& other) : ManagedMemory() { + std::swap(pointer_, other.pointer_); + } + + template >> + ManagedMemory(ManagedMemory&& other) // NOLINT + : ManagedMemory() { + std::swap(pointer_, other.pointer_); + } + + ~ManagedMemory() { Unref(); } + + ManagedMemory& operator=(const ManagedMemory& other) { + if (this != &other) { + other.Ref(); + Unref(); + pointer_ = other.pointer_; + } + return *this; + } + + template + std::enable_if_t, ManagedMemory&> // NOLINT + operator=(const ManagedMemory& other) { + if (this != &other) { + other.Ref(); + Unref(); + pointer_ = other.pointer_; + } + return *this; + } + + ManagedMemory& operator=(ManagedMemory&& other) { + if (this != &other) { + Unref(); + pointer_ = 0; + std::swap(pointer_, other.pointer_); + } + return *this; + } + + template + std::enable_if_t, ManagedMemory&> // NOLINT + operator=(ManagedMemory&& other) { + if (this != &other) { + Unref(); + pointer_ = 0; + std::swap(pointer_, other.pointer_); + } + return *this; + } + + T* get() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return reinterpret_cast(pointer_ & kPointerMask); + } + + T& operator*() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_ASSERT(get() != nullptr); + return *get(); + } + + T* operator->() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_ASSERT(get() != nullptr); + return get(); + } + + explicit operator bool() const { return get() != nullptr; } + + ABSL_MUST_USE_RESULT T* release() { + if (pointer_ == 0) { + return nullptr; + } + ABSL_ASSERT((pointer_ & kPointerArenaAllocated) == kPointerArenaAllocated); + T* pointer = get(); + pointer_ = 0; + return pointer; + } + + private: + friend class cel::MemoryManager; + + template + friend F* ManagedMemoryRelease(ManagedMemory& managed_memory); + + explicit ManagedMemory(T* pointer) { + ABSL_ASSERT(absl::countr_zero(reinterpret_cast(pointer)) >= 2); + pointer_ = + reinterpret_cast(pointer) | + (Metadata::For(pointer)->IsArenaAllocated() ? kPointerArenaAllocated + : 0); + } + + void Ref() const { + if (pointer_ != 0 && (pointer_ & kPointerArenaAllocated) == 0) { + Metadata::For(this)->Ref(); + } + } + + void Unref() const { + if (pointer_ != 0 && (pointer_ & kPointerArenaAllocated) == 0 && + Metadata::For(get())->Unref()) { + delete static_cast(get()); + } + } + + uintptr_t pointer_ = 0; +}; + +template +T* ManagedMemoryRelease(ManagedMemory& managed_memory) { + T* pointer = managed_memory.get(); + managed_memory.pointer_ = 0; + return pointer; +} + +using ManagedMemoryDestructor = void (*)(void*); + +// Shared state used by `ManagedMemory` that holds the reference count +// and destructor to call when the reference count hits 0. `MemoryManager` +// places `T` and `ManagedMemoryState` in the same allocation. Whether +// `ManagedMemoryState` is before or after `T` depends on alignment requirements +// of `T`. +class ManagedMemoryState final { + public: + static std::pair New( + size_t size, size_t align, ManagedMemoryDestructor destructor); + + ManagedMemoryState() = delete; + ManagedMemoryState(const ManagedMemoryState&) = delete; + ManagedMemoryState(ManagedMemoryState&&) = delete; + ManagedMemoryState& operator=(const ManagedMemoryState&) = delete; + ManagedMemoryState& operator=(ManagedMemoryState&&) = delete; + + void Ref() { + const auto reference_count = + reference_count_.fetch_add(1, std::memory_order_relaxed); + ABSL_ASSERT(reference_count > 0); + } + + ABSL_MUST_USE_RESULT bool Unref() { + const auto reference_count = + reference_count_.fetch_sub(1, std::memory_order_seq_cst); + ABSL_ASSERT(reference_count > 0); + return reference_count == 1; + } + + void Delete(void* pointer); + + private: + explicit ManagedMemoryState(ManagedMemoryDestructor destructor) + : reference_count_(1), destructor_(destructor) {} + + mutable std::atomic reference_count_; + ManagedMemoryDestructor destructor_; +}; + +// ManagedMemory implementation for T that is not derived from Data. This is +// very similar to `std::shared_ptr`. +template +class ManagedMemory final { + public: + ManagedMemory() = default; + + explicit ManagedMemory(std::nullptr_t) : ManagedMemory() {} + + ManagedMemory(const ManagedMemory& other) + : pointer_(other.pointer_), state_(other.state_) { + Ref(); + } + + template >> + ManagedMemory(const ManagedMemory& other) // NOLINT + : pointer_(static_cast(other.pointer_)), state_(other.state_) { + Ref(); + } + + ManagedMemory(ManagedMemory&& other) : ManagedMemory() { + std::swap(pointer_, other.pointer_); + std::swap(state_, other.state_); + } + + template >> + ManagedMemory(ManagedMemory&& other) // NOLINT + : pointer_(static_cast(other.pointer_)), state_(other.state_) { + other.pointer_ = nullptr; + other.state_ = nullptr; + } + + ~ManagedMemory() { Unref(); } + + ManagedMemory& operator=(const ManagedMemory& other) { + if (this != &other) { + other.Ref(); + Unref(); + pointer_ = other.pointer_; + state_ = other.state_; + } + return *this; + } + + template + std::enable_if_t, ManagedMemory&> // NOLINT + operator=(const ManagedMemory& other) { + if (this != &other) { + other.Ref(); + Unref(); + pointer_ = static_cast(other.pointer_); + state_ = other.state_; + } + return *this; + } + + ManagedMemory& operator=(ManagedMemory&& other) { + if (this != &other) { + Unref(); + pointer_ = nullptr; + state_ = nullptr; + std::swap(pointer_, other.pointer_); + std::swap(state_, other.state_); + } + return *this; + } + + template + std::enable_if_t, ManagedMemory&> // NOLINT + operator=(ManagedMemory&& other) { + if (this != &other) { + Unref(); + pointer_ = static_cast(other.pointer_); + state_ = other.state_; + other.pointer_ = nullptr; + other.state_ = nullptr; + } + return *this; + } + + T* get() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return pointer_; } + + T& operator*() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_ASSERT(get() != nullptr); + return *get(); + } + + T* operator->() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_ASSERT(get() != nullptr); + return get(); + } + + explicit operator bool() const { return get() != nullptr; } + + ABSL_MUST_USE_RESULT T* release() { + if (pointer_ == nullptr) { + return nullptr; + } + ABSL_ASSERT(state_ == nullptr); + T* pointer = pointer_; + pointer_ = nullptr; + return pointer; + } + + private: + friend class cel::MemoryManager; + + ManagedMemory(T* pointer, ManagedMemoryState* state) + : pointer_(pointer), state_(state) {} + + void Ref() const { + if (state_ != nullptr) { + state_->Ref(); + } + } + + void Unref() const { + if (state_ != nullptr && state_->Unref()) { + state_->Delete(const_cast(static_cast(get()))); + } + } + + T* pointer_ = nullptr; + ManagedMemoryState* state_ = nullptr; +}; + +template +constexpr bool operator==(const ManagedMemory& lhs, std::nullptr_t) { + return !static_cast(lhs); +} + +template +constexpr bool operator==(std::nullptr_t, const ManagedMemory& rhs) { + return !static_cast(rhs); +} + +template +constexpr bool operator!=(const ManagedMemory& lhs, std::nullptr_t) { + return !operator==(lhs, nullptr); +} + +template +constexpr bool operator!=(std::nullptr_t, const ManagedMemory& rhs) { + return !operator==(nullptr, rhs); +} + +} // namespace base_internal + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_INTERNAL_MANAGED_MEMORY_H_ diff --git a/base/internal/memory_manager.pre.h b/base/internal/memory_manager.h similarity index 51% rename from base/internal/memory_manager.pre.h rename to base/internal/memory_manager.h index 741142b75..ce38458d0 100644 --- a/base/internal/memory_manager.pre.h +++ b/base/internal/memory_manager.h @@ -12,48 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -// IWYU pragma: private, include "base/memory_manager.h" - #ifndef THIRD_PARTY_CEL_CPP_BASE_INTERNAL_MEMORY_MANAGER_PRE_H_ #define THIRD_PARTY_CEL_CPP_BASE_INTERNAL_MEMORY_MANAGER_PRE_H_ #include -#include - -namespace cel { - -template -class ManagedMemory; -class MemoryManager; -namespace base_internal { +namespace cel::base_internal { size_t GetPageSize(); -class Resource; - -template -constexpr size_t GetManagedMemorySize(const ManagedMemory& managed_memory); - template -constexpr size_t GetManagedMemoryAlignment( - const ManagedMemory& managed_memory); - -template -constexpr T* ManagedMemoryRelease(ManagedMemory& managed_memory); - -MemoryManager& GetMemoryManager(const void* pointer, size_t size, size_t align); - -template -class MemoryManagerDestructor final { - private: - friend class cel::MemoryManager; - - static void Destruct(void* pointer) { reinterpret_cast(pointer)->~T(); } +struct MemoryManagerDestructor final { + static void Destruct(void* pointer) { static_cast(pointer)->~T(); } }; -} // namespace base_internal - -} // namespace cel +} // namespace cel::base_internal #endif // THIRD_PARTY_CEL_CPP_BASE_INTERNAL_MEMORY_MANAGER_PRE_H_ diff --git a/base/internal/memory_manager.post.h b/base/internal/memory_manager.post.h deleted file mode 100644 index 3dec55578..000000000 --- a/base/internal/memory_manager.post.h +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright 2022 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// IWYU pragma: private, include "base/memory_manager.h" - -#ifndef THIRD_PARTY_CEL_CPP_BASE_INTERNAL_MEMORY_MANAGER_POST_H_ -#define THIRD_PARTY_CEL_CPP_BASE_INTERNAL_MEMORY_MANAGER_POST_H_ - -namespace cel::base_internal { - -template -constexpr size_t GetManagedMemorySize(const ManagedMemory& managed_memory) { - return managed_memory.size_; -} - -template -constexpr size_t GetManagedMemoryAlignment( - const ManagedMemory& managed_memory) { - return managed_memory.align_; -} - -template -constexpr T* ManagedMemoryRelease(ManagedMemory& managed_memory) { - // Like ManagedMemory::release except there is no assert. For use during - // handle creation. - T* ptr = managed_memory.ptr_; - managed_memory.ptr_ = nullptr; - managed_memory.size_ = managed_memory.align_ = 0; - return ptr; -} - -inline MemoryManager& GetMemoryManager(const void* pointer, size_t size, - size_t align) { - return MemoryManager::Get(pointer, size, align); -} - -} // namespace cel::base_internal - -#endif // THIRD_PARTY_CEL_CPP_BASE_INTERNAL_MEMORY_MANAGER_POST_H_ diff --git a/base/internal/type.pre.h b/base/internal/type.h similarity index 51% rename from base/internal/type.pre.h rename to base/internal/type.h index 80ecacc54..8735dc26b 100644 --- a/base/internal/type.pre.h +++ b/base/internal/type.h @@ -14,12 +14,15 @@ // IWYU pragma: private, include "base/type.h" -#ifndef THIRD_PARTY_CEL_CPP_BASE_INTERNAL_TYPE_PRE_H_ -#define THIRD_PARTY_CEL_CPP_BASE_INTERNAL_TYPE_PRE_H_ +#ifndef THIRD_PARTY_CEL_CPP_BASE_INTERNAL_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_BASE_INTERNAL_TYPE_H_ #include #include "base/handle.h" +#include "base/internal/data.h" +#include "base/kind.h" +#include "internal/casts.h" #include "internal/rtti.h" namespace cel { @@ -29,28 +32,25 @@ class StructType; namespace base_internal { -class TypeHandleBase; -template -class TypeHandle; - -// Convenient aliases. -using PersistentTypeHandle = TypeHandle; - -// As all objects should be aligned to at least 4 bytes, we can use the lower -// two bits for our own purposes. -inline constexpr uintptr_t kTypeHandleUnmanaged = 1 << 0; -inline constexpr uintptr_t kTypeHandleReserved = 1 << 1; -inline constexpr uintptr_t kTypeHandleBits = - kTypeHandleUnmanaged | kTypeHandleReserved; -inline constexpr uintptr_t kTypeHandleMask = ~kTypeHandleBits; +class PersistentTypeHandle; class ListTypeImpl; class MapTypeImpl; +template +class SimpleType; +template +class SimpleValue; + internal::TypeInfo GetEnumTypeTypeId(const EnumType& enum_type); internal::TypeInfo GetStructTypeTypeId(const StructType& struct_type); +inline constexpr size_t kTypeInlineSize = sizeof(void*); +inline constexpr size_t kTypeInlineAlign = alignof(void*); + +struct AnyType final : public AnyData {}; + } // namespace base_internal } // namespace cel @@ -63,14 +63,13 @@ internal::TypeInfo GetStructTypeTypeId(const StructType& struct_type); template class Persistent; \ template class Persistent -#define CEL_INTERNAL_DECLARE_TYPE(base, derived) \ - private: \ - friend class ::cel::base_internal::TypeHandleBase; \ - \ - static bool Is(const ::cel::Type& type); \ - \ - ::std::pair<::std::size_t, ::std::size_t> SizeAndAlignment() const override; \ - \ +#define CEL_INTERNAL_DECLARE_TYPE(base, derived) \ + public: \ + static bool Is(const ::cel::Type& type); \ + \ + private: \ + friend class ::cel::base_internal::PersistentTypeHandle; \ + \ ::cel::internal::TypeInfo TypeId() const override; #define CEL_INTERNAL_IMPLEMENT_TYPE(base, derived) \ @@ -81,23 +80,12 @@ internal::TypeInfo GetStructTypeTypeId(const StructType& struct_type); bool derived::Is(const ::cel::Type& type) { \ return type.kind() == ::cel::Kind::k##base && \ ::cel::base_internal::Get##base##TypeTypeId( \ - ::cel::internal::down_cast(type)) == \ + static_cast(type)) == \ ::cel::internal::TypeId(); \ } \ \ - ::std::pair<::std::size_t, ::std::size_t> derived::SizeAndAlignment() \ - const { \ - static_assert( \ - ::std::is_same_v>>, \ - "this must be the same as " #derived); \ - return ::std::pair<::std::size_t, ::std::size_t>(sizeof(derived), \ - alignof(derived)); \ - } \ - \ ::cel::internal::TypeInfo derived::TypeId() const { \ return ::cel::internal::TypeId(); \ } -#endif // THIRD_PARTY_CEL_CPP_BASE_INTERNAL_TYPE_PRE_H_ +#endif // THIRD_PARTY_CEL_CPP_BASE_INTERNAL_TYPE_H_ diff --git a/base/internal/type.post.h b/base/internal/type.post.h deleted file mode 100644 index 215493351..000000000 --- a/base/internal/type.post.h +++ /dev/null @@ -1,190 +0,0 @@ -// Copyright 2022 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// IWYU pragma: private, include "base/type.h" - -#ifndef THIRD_PARTY_CEL_CPP_BASE_INTERNAL_TYPE_POST_H_ -#define THIRD_PARTY_CEL_CPP_BASE_INTERNAL_TYPE_POST_H_ - -#include -#include -#include -#include - -#include "absl/base/macros.h" -#include "absl/base/optimization.h" -#include "absl/hash/hash.h" -#include "absl/numeric/bits.h" -#include "base/handle.h" -#include "internal/rtti.h" - -namespace cel { - -namespace base_internal { - -// Base implementation of persistent and transient handles for types. This -// contains implementation details shared among both, but is never used -// directly. The derived classes are responsible for defining appropriate -// constructors and assignments. -class TypeHandleBase { - public: - constexpr TypeHandleBase() = default; - - // Used by derived classes to bypass default construction to perform their own - // construction. - explicit TypeHandleBase(HandleInPlace) {} - - // Called by `Transient` and `Persistent` to implement the same operator. They - // will handle enforcing const correctness. - Type& operator*() const { return get(); } - - // Called by `Transient` and `Persistent` to implement the same operator. They - // will handle enforcing const correctness. - Type* operator->() const { return std::addressof(get()); } - - // Called by internal accessors `base_internal::IsXHandle`. - constexpr bool IsManaged() const { - return (rep_ & kTypeHandleUnmanaged) == 0; - } - - // Called by internal accessors `base_internal::IsXHandle`. - constexpr bool IsUnmanaged() const { - return (rep_ & kTypeHandleUnmanaged) != 0; - } - - // Called by internal accessors `base_internal::IsXHandle`. - constexpr bool IsInlined() const { return false; } - - // Called by `Transient` and `Persistent` to implement the same function. - template - bool Is() const { - return static_cast(*this) && T::Is(static_cast(**this)); - } - - // Called by `Transient` and `Persistent` to implement the same operator. - explicit operator bool() const { return (rep_ & kTypeHandleMask) != 0; } - - // Called by `Transient` and `Persistent` to implement the same operator. - friend bool operator==(const TypeHandleBase& lhs, const TypeHandleBase& rhs) { - if (static_cast(lhs) != static_cast(rhs) || - !static_cast(lhs)) { - return false; - } - const Type& lhs_type = lhs.get(); - const Type& rhs_type = rhs.get(); - return lhs_type.Equals(rhs_type); - } - - // Called by `Transient` and `Persistent` to implement std::swap. - friend void swap(TypeHandleBase& lhs, TypeHandleBase& rhs) { - std::swap(lhs.rep_, rhs.rep_); - } - - template - friend H AbslHashValue(H state, const TypeHandleBase& handle) { - if (ABSL_PREDICT_TRUE(static_cast(handle))) { - handle.get().HashValue(absl::HashState::Create(&state)); - } - return state; - } - - private: - template - friend class TypeHandle; - - void Unref() const { - if ((rep_ & kTypeHandleUnmanaged) == 0) { - get().Unref(); - } - } - - uintptr_t Ref() const { - if ((rep_ & kTypeHandleUnmanaged) == 0) { - get().Ref(); - } - return rep_; - } - - Type& get() const { return *reinterpret_cast(rep_ & kTypeHandleMask); } - - // There are no inlined types, so we represent everything as a pointer and use - // tagging to differentiate between reference counted and arena-allocated. - uintptr_t rep_ = kTypeHandleUnmanaged; -}; - -// All methods are called by `Persistent`. -template <> -class TypeHandle final : public TypeHandleBase { - public: - constexpr TypeHandle() = default; - - TypeHandle(const PersistentTypeHandle& other) { rep_ = other.Ref(); } - - TypeHandle(PersistentTypeHandle&& other) { - rep_ = other.rep_; - other.rep_ = kTypeHandleUnmanaged; - } - - template - TypeHandle(UnmanagedResource, F& from) : TypeHandleBase(kHandleInPlace) { - uintptr_t rep = reinterpret_cast( - static_cast(static_cast(std::addressof(from)))); - ABSL_ASSERT(absl::countr_zero(rep) >= - 2); // Verify the lower 2 bits are available. - rep_ = rep | kTypeHandleUnmanaged; - } - - template - TypeHandle(ManagedResource, F& from) : TypeHandleBase(kHandleInPlace) { - uintptr_t rep = reinterpret_cast( - static_cast(static_cast(std::addressof(from)))); - ABSL_ASSERT(absl::countr_zero(rep) >= - 2); // Verify the lower 2 bits are available. - rep_ = rep; - } - - ~TypeHandle() { Unref(); } - - TypeHandle& operator=(const PersistentTypeHandle& other) { - Unref(); - rep_ = other.Ref(); - return *this; - } - - TypeHandle& operator=(PersistentTypeHandle&& other) { - Unref(); - rep_ = other.rep_; - other.rep_ = kTypeHandleUnmanaged; - return *this; - } -}; - -// Specialization for Type providing the implementation to `Persistent`. -template <> -struct HandleTraits { - using handle_type = TypeHandle; -}; - -// Partial specialization for `Persistent` for all classes derived from Type. -template -struct HandleTraits< - HandleType::kPersistent, T, - std::enable_if_t<(std::is_base_of_v && !std::is_same_v)>> - final : public HandleTraits {}; - -} // namespace base_internal - -} // namespace cel - -#endif // THIRD_PARTY_CEL_CPP_BASE_INTERNAL_TYPE_POST_H_ diff --git a/base/internal/value.h b/base/internal/value.h new file mode 100644 index 000000000..d8fd305ce --- /dev/null +++ b/base/internal/value.h @@ -0,0 +1,159 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "base/value.h" + +#ifndef THIRD_PARTY_CEL_CPP_BASE_INTERNAL_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_BASE_INTERNAL_VALUE_H_ + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "absl/types/variant.h" +#include "base/handle.h" +#include "base/internal/data.h" +#include "base/types/enum_type.h" +#include "internal/rtti.h" + +namespace cel { + +class BytesValue; +class StringValue; +class StructValue; +class ListValue; +class MapValue; + +namespace base_internal { + +template +class SimpleValue; + +class PersistentValueHandle; + +internal::TypeInfo GetStructValueTypeId(const StructValue& struct_value); + +internal::TypeInfo GetListValueTypeId(const ListValue& list_value); + +internal::TypeInfo GetMapValueTypeId(const MapValue& map_value); + +static_assert(std::is_trivially_copyable_v, + "absl::Duration must be trivially copyable."); +static_assert(std::is_trivially_destructible_v, + "absl::Duration must be trivially destructible."); + +static_assert(std::is_trivially_copyable_v, + "absl::Time must be trivially copyable."); +static_assert(std::is_trivially_destructible_v, + "absl::Time must be trivially destructible."); + +static_assert(std::is_trivially_copyable_v, + "absl::string_view must be trivially copyable."); +static_assert(std::is_trivially_destructible_v, + "absl::string_view must be trivially destructible."); + +struct InlineValue final { + uintptr_t vptr; + union { + bool bool_value; + int64_t int64_value; + uint64_t uint64_value; + double double_value; + uintptr_t pointer_value; + absl::Duration duration_value; + absl::Time time_value; + absl::Status status_value; + absl::Cord cord_value; + absl::string_view string_value; + struct { + Persistent type; + int64_t number; + } enum_value; + }; +}; + +inline constexpr size_t kValueInlineSize = sizeof(InlineValue); +inline constexpr size_t kValueInlineAlign = alignof(InlineValue); + +static_assert(kValueInlineSize <= 32, + "Size of an inline value should be less than 32 bytes."); +static_assert(kValueInlineAlign <= alignof(std::max_align_t), + "Alignment of an inline value should not be overaligned."); + +struct AnyValue final : public AnyData {}; + +class InlinedCordBytesValue; +class InlinedStringViewBytesValue; +class StringBytesValue; +class InlinedCordStringValue; +class InlinedStringViewStringValue; +class StringStringValue; + +using StringValueRep = + absl::variant>; +using BytesValueRep = + absl::variant>; + +} // namespace base_internal + +namespace interop_internal { + +base_internal::StringValueRep GetStringValueRep( + const Persistent& value); +base_internal::BytesValueRep GetBytesValueRep( + const Persistent& value); + +} // namespace interop_internal + +} // namespace cel + +#define CEL_INTERNAL_VALUE_DECL(name) \ + extern template class Persistent; \ + extern template class Persistent + +#define CEL_INTERNAL_VALUE_IMPL(name) \ + template class Persistent; \ + template class Persistent + +#define CEL_INTERNAL_DECLARE_VALUE(base, derived) \ + public: \ + static bool Is(const ::cel::Value& value); \ + \ + private: \ + friend class ::cel::base_internal::PersistentValueHandle; \ + \ + ::cel::internal::TypeInfo TypeId() const override; + +#define CEL_INTERNAL_IMPLEMENT_VALUE(base, derived) \ + static_assert(::std::is_base_of_v<::cel::base##Value, derived>, \ + #derived " must inherit from cel::" #base "Value"); \ + static_assert(!::std::is_abstract_v, "this must not be abstract"); \ + \ + bool derived::Is(const ::cel::Value& value) { \ + return value.kind() == ::cel::Kind::k##base && \ + ::cel::base_internal::Get##base##ValueTypeId( \ + static_cast(value)) == \ + ::cel::internal::TypeId(); \ + } \ + \ + ::cel::internal::TypeInfo derived::TypeId() const { \ + return ::cel::internal::TypeId(); \ + } + +#endif // THIRD_PARTY_CEL_CPP_BASE_INTERNAL_VALUE_H_ diff --git a/base/internal/value.post.h b/base/internal/value.post.h deleted file mode 100644 index f6753c543..000000000 --- a/base/internal/value.post.h +++ /dev/null @@ -1,314 +0,0 @@ -// Copyright 2022 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// IWYU pragma: private, include "base/value.h" - -#ifndef THIRD_PARTY_CEL_CPP_BASE_INTERNAL_VALUE_POST_H_ -#define THIRD_PARTY_CEL_CPP_BASE_INTERNAL_VALUE_POST_H_ - -#include -#include -#include -#include -#include -#include - -#include "absl/base/macros.h" -#include "absl/base/optimization.h" -#include "absl/hash/hash.h" -#include "absl/strings/string_view.h" -#include "base/handle.h" -#include "internal/casts.h" - -namespace cel { - -namespace base_internal { - -struct ABSL_ATTRIBUTE_UNUSED CheckVptrOffsetBase { - virtual ~CheckVptrOffsetBase() = default; - - virtual void Member() const {} -}; - -// Class used to assert the object memory layout for vptr at compile time, -// otherwise it is unused. -struct ABSL_ATTRIBUTE_UNUSED CheckVptrOffset final - : public CheckVptrOffsetBase { - uintptr_t member; -}; - -// Ensure the hidden vptr is stored at the beginning of the object. See -// ValueHandleData for more information. -static_assert(offsetof(CheckVptrOffset, member) == sizeof(void*), - "CEL C++ requires a compiler that stores the vptr as a hidden " - "member at the beginning of the object. If this static_assert " - "fails, please reach out to the CEL team."); - -// Union of all known inlinable values. -union ValueHandleData final { - // As asserted above, we rely on the fact that the compiler stores the vptr as - // a hidden member at the beginning of the object. We then re-use the first 2 - // bits to differentiate between an inlined value (both 0), a heap allocated - // reference counted value, or a arena allocated value. - void* vptr; - alignas(std::max_align_t) char padding[32]; -}; - -// Base implementation of persistent and transient handles for values. This -// contains implementation details shared among both, but is never used -// directly. The derived classes are responsible for defining appropriate -// constructors and assignments. -class ValueHandleBase { - public: - ValueHandleBase() { Reset(); } - - // Used by derived classes to bypass default construction to perform their own - // construction. - explicit ValueHandleBase(HandleInPlace) {} - - // Called by `Transient` and `Persistent` to implement the same operator. They - // will handle enforcing const correctness. - Value& operator*() const { return get(); } - - // Called by `Transient` and `Persistent` to implement the same operator. They - // will handle enforcing const correctness. - Value* operator->() const { return std::addressof(get()); } - - // Called by internal accessors `base_internal::IsXHandle`. - bool IsManaged() const { return (vptr() & kValueHandleManaged) != 0; } - - // Called by internal accessors `base_internal::IsXHandle`. - bool IsUnmanaged() const { return (vptr() & kValueHandleUnmanaged) != 0; } - - // Called by internal accessors `base_internal::IsXHandle`. - bool IsInlined() const { return (vptr() & kValueHandleBits) == 0; } - - // Called by `Transient` and `Persistent` to implement the same function. - template - bool Is() const { - // Tests that this is not an empty handle and then dereferences the handle - // calling the RTTI-like implementation T::Is which takes `const Value&`. - return static_cast(*this) && T::Is(static_cast(**this)); - } - - // Called by `Transient` and `Persistent` to implement the same operator. - explicit operator bool() const { return (vptr() & kValueHandleMask) != 0; } - - // Called by `Transient` and `Persistent` to implement the same operator. - friend bool operator==(const ValueHandleBase& lhs, - const ValueHandleBase& rhs) { - if (static_cast(lhs) != static_cast(rhs) || - !static_cast(lhs)) { - return false; - } - const Value& lhs_value = lhs.get(); - const Value& rhs_value = rhs.get(); - return lhs_value.Equals(rhs_value); - } - - // Called by `Transient` and `Persistent` to implement std::swap. - friend void swap(ValueHandleBase& lhs, ValueHandleBase& rhs) { - if (lhs.empty_or_not_inlined() && rhs.empty_or_not_inlined()) { - // Both `lhs` and `rhs` are simple pointers. Just swap them. - std::swap(lhs.data_.vptr, rhs.data_.vptr); - return; - } - ValueHandleBase tmp; - Move(lhs, tmp); - Move(rhs, lhs); - Move(tmp, rhs); - } - - template - friend H AbslHashValue(H state, const ValueHandleBase& handle) { - if (ABSL_PREDICT_TRUE(static_cast(handle))) { - handle.get().HashValue(absl::HashState::Create(&state)); - } - return state; - } - - private: - template - friend class ValueHandle; - - // Resets the state to the same as the default constructor. Does not perform - // any destruction of existing content. - void Reset() { data_.vptr = reinterpret_cast(kValueHandleUnmanaged); } - - void Unref() const { - ABSL_ASSERT(reffed()); - reinterpret_cast(vptr() & kValueHandleMask)->Unref(); - } - - void Ref() const { - ABSL_ASSERT(reffed()); - reinterpret_cast(vptr() & kValueHandleMask)->Ref(); - } - - Value& get() const { - return *(inlined() - ? reinterpret_cast(const_cast(&data_.vptr)) - : reinterpret_cast(vptr() & kValueHandleMask)); - } - - bool empty() const { return !static_cast(*this); } - - // Does the stored data represent an inlined value? - bool inlined() const { return (vptr() & kValueHandleBits) == 0; } - - // Does the stored data represent a non-null inlined value? - bool not_empty_and_inlined() const { - return (vptr() & kValueHandleBits) == 0 && (vptr() & kValueHandleMask) != 0; - } - - // Does the stored data represent null, heap allocated reference counted, or - // arena allocated value? - bool empty_or_not_inlined() const { - return (vptr() & kValueHandleBits) != 0 || (vptr() & kValueHandleMask) == 0; - } - - // Does the stored data required reference counting? - bool reffed() const { return (vptr() & kValueHandleManaged) != 0; } - - uintptr_t vptr() const { return reinterpret_cast(data_.vptr); } - - static void Copy(const ValueHandleBase& from, ValueHandleBase& to) { - if (from.empty_or_not_inlined()) { - // `from` is a simple pointer, just copy it. - to.data_.vptr = from.data_.vptr; - } else { - from.get().CopyTo(*reinterpret_cast(&to.data_.vptr)); - } - } - - static void Move(ValueHandleBase& from, ValueHandleBase& to) { - if (from.empty_or_not_inlined()) { - // `from` is a simple pointer, just swap it. - std::swap(from.data_.vptr, to.data_.vptr); - } else { - from.get().MoveTo(*reinterpret_cast(&to.data_.vptr)); - DestructInlined(from); - } - } - - static void DestructInlined(ValueHandleBase& handle) { - ABSL_ASSERT(!handle.empty_or_not_inlined()); - handle.get().~Value(); - handle.Reset(); - } - - ValueHandleData data_; -}; - -// All methods are called by `Persistent`. -template <> -class ValueHandle final : public ValueHandleBase { - private: - using Base = ValueHandleBase; - - public: - ValueHandle() = default; - - template - explicit ValueHandle(InlinedResource, Args&&... args) - : ValueHandleBase(kHandleInPlace) { - static_assert(sizeof(T) <= sizeof(data_.padding), - "T cannot be inlined in Handle"); - static_assert(alignof(T) <= alignof(data_.padding), - "T cannot be inlined in Handle"); - ::new (const_cast(static_cast(&data_.padding))) - T(std::forward(args)...); - ABSL_ASSERT(absl::countr_zero(vptr()) >= - 2); // Verify the lower 2 bits are available. - } - - template - ValueHandle(UnmanagedResource, F& from) : ValueHandleBase(kHandleInPlace) { - uintptr_t vptr = reinterpret_cast( - static_cast(static_cast(std::addressof(from)))); - ABSL_ASSERT(absl::countr_zero(vptr) >= - 2); // Verify the lower 2 bits are available. - data_.vptr = reinterpret_cast(vptr | kValueHandleUnmanaged); - } - - template - ValueHandle(ManagedResource, F& from) : ValueHandleBase(kHandleInPlace) { - uintptr_t vptr = reinterpret_cast( - static_cast(static_cast(std::addressof(from)))); - ABSL_ASSERT(absl::countr_zero(vptr) >= - 2); // Verify the lower 2 bits are available. - data_.vptr = reinterpret_cast(vptr | kValueHandleManaged); - } - - ValueHandle(const PersistentValueHandle& other) : ValueHandle() { - Base::Copy(other, *this); - if (reffed()) { - Ref(); - } - } - - ValueHandle(PersistentValueHandle&& other) : ValueHandle() { - Base::Move(other, *this); - } - - ~ValueHandle() { - if (not_empty_and_inlined()) { - DestructInlined(*this); - } else if (reffed()) { - Unref(); - } - } - - ValueHandle& operator=(const PersistentValueHandle& other) { - if (not_empty_and_inlined()) { - DestructInlined(*this); - } else if (reffed()) { - Unref(); - } - Base::Copy(other, *this); - if (reffed()) { - Ref(); - } - return *this; - } - - ValueHandle& operator=(PersistentValueHandle&& other) { - if (not_empty_and_inlined()) { - DestructInlined(*this); - } else if (reffed()) { - Unref(); - Reset(); - } - Base::Move(other, *this); - return *this; - } -}; -// Specialization for Value providing the implementation to `Persistent`. -template <> -struct HandleTraits { - using handle_type = ValueHandle; -}; - -// Partial specialization for `Persistent` for all classes derived from Value. -template -struct HandleTraits && - !std::is_same_v)>> - final : public HandleTraits {}; - -} // namespace base_internal - -} // namespace cel - -#endif // THIRD_PARTY_CEL_CPP_BASE_INTERNAL_VALUE_POST_H_ diff --git a/base/internal/value.pre.h b/base/internal/value.pre.h deleted file mode 100644 index 3cda23dc8..000000000 --- a/base/internal/value.pre.h +++ /dev/null @@ -1,241 +0,0 @@ -// Copyright 2022 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// IWYU pragma: private, include "base/value.h" - -#ifndef THIRD_PARTY_CEL_CPP_BASE_INTERNAL_VALUE_PRE_H_ -#define THIRD_PARTY_CEL_CPP_BASE_INTERNAL_VALUE_PRE_H_ - -#include -#include -#include -#include - -#include "absl/strings/cord.h" -#include "absl/strings/string_view.h" -#include "absl/types/variant.h" -#include "base/handle.h" -#include "internal/rtti.h" - -namespace cel { - -class EnumValue; -class StructValue; -class ListValue; -class MapValue; - -namespace base_internal { - -class ValueHandleBase; -template -class ValueHandle; - -// Convenient aliases. -using PersistentValueHandle = ValueHandle; - -// As all objects should be aligned to at least 4 bytes, we can use the lower -// two bits for our own purposes. -inline constexpr uintptr_t kValueHandleManaged = 1 << 0; -inline constexpr uintptr_t kValueHandleUnmanaged = 1 << 1; -inline constexpr uintptr_t kValueHandleBits = - kValueHandleManaged | kValueHandleUnmanaged; -inline constexpr uintptr_t kValueHandleMask = ~kValueHandleBits; - -internal::TypeInfo GetEnumValueTypeId(const EnumValue& enum_value); - -internal::TypeInfo GetStructValueTypeId(const StructValue& struct_value); - -internal::TypeInfo GetListValueTypeId(const ListValue& list_value); - -internal::TypeInfo GetMapValueTypeId(const MapValue& map_value); - -class InlinedCordBytesValue; -class InlinedStringViewBytesValue; -class StringBytesValue; -class ExternalDataBytesValue; -class InlinedCordStringValue; -class InlinedStringViewStringValue; -class StringStringValue; -class ExternalDataStringValue; - -// Type erased state capable of holding a pointer to remote storage or storing -// objects less than two pointers in size inline. -union ExternalDataReleaserState final { - void* remote; - alignas(alignof(std::max_align_t)) char local[sizeof(void*) * 2]; -}; - -// Function which deletes the object referenced by ExternalDataReleaserState. -using ExternalDataReleaserDeleter = void(ExternalDataReleaserState* state); - -template -void LocalExternalDataReleaserDeleter(ExternalDataReleaserState* state) { - reinterpret_cast(&state->local)->~Releaser(); -} - -template -void RemoteExternalDataReleaserDeleter(ExternalDataReleaserState* state) { - ::delete reinterpret_cast(state->remote); -} - -// Function which invokes the object referenced by ExternalDataReleaserState. -using ExternalDataReleaseInvoker = - void(ExternalDataReleaserState* state) noexcept; - -template -void LocalExternalDataReleaserInvoker( - ExternalDataReleaserState* state) noexcept { - (*reinterpret_cast(&state->local))(); -} - -template -void RemoteExternalDataReleaserInvoker( - ExternalDataReleaserState* state) noexcept { - (*reinterpret_cast(&state->remote))(); -} - -struct ExternalDataReleaser final { - ExternalDataReleaser() = delete; - - template - explicit ExternalDataReleaser(Releaser&& releaser) { - using DecayedReleaser = std::decay_t; - if constexpr (sizeof(DecayedReleaser) <= sizeof(void*) * 2 && - alignof(DecayedReleaser) <= alignof(std::max_align_t)) { - // Object meets size and alignment constraints, will be stored - // inline in ExternalDataReleaserState.local. - ::new (static_cast(&state.local)) - DecayedReleaser(std::forward(releaser)); - invoker = LocalExternalDataReleaserInvoker; - if constexpr (std::is_trivially_destructible_v) { - // Object is trivially destructable, no need to call destructor at all. - deleter = nullptr; - } else { - deleter = LocalExternalDataReleaserDeleter; - } - } else { - // Object does not meet size and alignment constraints, allocate on the - // heap and store pointer in ExternalDataReleaserState::remote. inline in - // ExternalDataReleaserState::local. - state.remote = ::new DecayedReleaser(std::forward(releaser)); - invoker = RemoteExternalDataReleaserInvoker; - deleter = RemoteExternalDataReleaserDeleter; - } - } - - ExternalDataReleaser(const ExternalDataReleaser&) = delete; - - ExternalDataReleaser(ExternalDataReleaser&&) = delete; - - ~ExternalDataReleaser() { - (*invoker)(&state); - if (deleter != nullptr) { - (*deleter)(&state); - } - } - - ExternalDataReleaser& operator=(const ExternalDataReleaser&) = delete; - - ExternalDataReleaser& operator=(ExternalDataReleaser&&) = delete; - - ExternalDataReleaserState state; - ExternalDataReleaserDeleter* deleter; - ExternalDataReleaseInvoker* invoker; -}; - -// Utility class encompassing a contiguous array of data which a function that -// must be called when the data is no longer needed. -struct ExternalData final { - ExternalData() = delete; - - ExternalData(const void* data, size_t size, - std::unique_ptr releaser) - : data(data), size(size), releaser(std::move(releaser)) {} - - ExternalData(const ExternalData&) = delete; - - ExternalData(ExternalData&&) noexcept = default; - - ExternalData& operator=(const ExternalData&) = delete; - - ExternalData& operator=(ExternalData&&) noexcept = default; - - const void* data; - size_t size; - std::unique_ptr releaser; -}; - -using StringValueRep = - absl::variant>; -using BytesValueRep = - absl::variant>; - -} // namespace base_internal - -} // namespace cel - -#define CEL_INTERNAL_VALUE_DECL(name) \ - extern template class Persistent; \ - extern template class Persistent - -#define CEL_INTERNAL_VALUE_IMPL(name) \ - template class Persistent; \ - template class Persistent - -// Both are equivalent to std::construct_at implementation from C++20. -#define CEL_INTERNAL_VALUE_COPY_TO(type, src, dest) \ - ::new (const_cast( \ - static_cast(std::addressof(dest)))) type(src) -#define CEL_INTERNAL_VALUE_MOVE_TO(type, src, dest) \ - ::new (const_cast(static_cast( \ - std::addressof(dest)))) type(std::move(src)) - -#define CEL_INTERNAL_DECLARE_VALUE(base, derived) \ - private: \ - friend class ::cel::base_internal::ValueHandleBase; \ - \ - static bool Is(const ::cel::Value& value); \ - \ - ::std::pair<::std::size_t, ::std::size_t> SizeAndAlignment() const override; \ - \ - ::cel::internal::TypeInfo TypeId() const override; - -#define CEL_INTERNAL_IMPLEMENT_VALUE(base, derived) \ - static_assert(::std::is_base_of_v<::cel::base##Value, derived>, \ - #derived " must inherit from cel::" #base "Value"); \ - static_assert(!::std::is_abstract_v, "this must not be abstract"); \ - \ - bool derived::Is(const ::cel::Value& value) { \ - return value.kind() == ::cel::Kind::k##base && \ - ::cel::base_internal::Get##base##ValueTypeId( \ - ::cel::internal::down_cast( \ - value)) == ::cel::internal::TypeId(); \ - } \ - \ - ::std::pair<::std::size_t, ::std::size_t> derived::SizeAndAlignment() \ - const { \ - static_assert( \ - ::std::is_same_v>>, \ - "this must be the same as " #derived); \ - return ::std::pair<::std::size_t, ::std::size_t>(sizeof(derived), \ - alignof(derived)); \ - } \ - \ - ::cel::internal::TypeInfo derived::TypeId() const { \ - return ::cel::internal::TypeId(); \ - } - -#endif // THIRD_PARTY_CEL_CPP_BASE_INTERNAL_VALUE_PRE_H_ diff --git a/base/kind.h b/base/kind.h index e7f6169a0..a32c432f8 100644 --- a/base/kind.h +++ b/base/kind.h @@ -15,12 +15,14 @@ #ifndef THIRD_PARTY_CEL_CPP_BASE_KIND_H_ #define THIRD_PARTY_CEL_CPP_BASE_KIND_H_ +#include + #include "absl/base/attributes.h" #include "absl/strings/string_view.h" namespace cel { -enum class Kind { +enum class Kind : uint8_t { kNullType = 0, kError, kDyn, @@ -38,6 +40,9 @@ enum class Kind { kList, kMap, kStruct, + + // INTERNAL: Do not exceed 127. Implementation details rely on the fact that + // we can store `Kind` using 7 bits. }; ABSL_ATTRIBUTE_PURE_FUNCTION absl::string_view KindToString(Kind kind); diff --git a/base/managed_memory.h b/base/managed_memory.h new file mode 100644 index 000000000..ae8628a1b --- /dev/null +++ b/base/managed_memory.h @@ -0,0 +1,36 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_MANAGED_MEMORY_H_ +#define THIRD_PARTY_CEL_CPP_BASE_MANAGED_MEMORY_H_ + +#include + +#include "base/internal/managed_memory.h" + +namespace cel { + +// `ManagedMemory` is a smart pointer which ensures any applicable object +// destructors and deallocation are eventually performed. Copying does not +// actually copy the underlying T, instead a pointer is copied and optionally +// reference counted. Moving does not actually move the underlying T, instead a +// pointer is moved. +// +// TODO(issues/5): consider feature parity with std::unique_ptr +template +using ManagedMemory = base_internal::ManagedMemory; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_MANAGED_MEMORY_H_ diff --git a/base/memory_manager.cc b/base/memory_manager.cc index 0f0d40522..9c2bd2341 100644 --- a/base/memory_manager.cc +++ b/base/memory_manager.cc @@ -42,8 +42,6 @@ #include #include -#include "absl/base/attributes.h" -#include "absl/base/call_once.h" #include "absl/base/config.h" #include "absl/base/dynamic_annotations.h" #include "absl/base/macros.h" @@ -56,55 +54,6 @@ namespace cel { namespace { -class GlobalMemoryManager final : public MemoryManager { - public: - GlobalMemoryManager() : MemoryManager() {} - - private: - AllocationResult Allocate(size_t size, size_t align) override { - void* pointer; - if (ABSL_PREDICT_TRUE(align <= alignof(std::max_align_t))) { - pointer = ::operator new(size, std::nothrow); - } else { - pointer = ::operator new(size, static_cast(align), - std::nothrow); - } - return {pointer}; - } - - void Deallocate(void* pointer, size_t size, size_t align) override { - if (ABSL_PREDICT_TRUE(align <= alignof(std::max_align_t))) { - ::operator delete(pointer, size); - } else { - ::operator delete(pointer, size, static_cast(align)); - } - } -}; - -struct ControlBlock final { - constexpr explicit ControlBlock(MemoryManager* memory_manager) - : refs(1), memory_manager(memory_manager) {} - - ControlBlock(const ControlBlock&) = delete; - ControlBlock(ControlBlock&&) = delete; - ControlBlock& operator=(const ControlBlock&) = delete; - ControlBlock& operator=(ControlBlock&&) = delete; - - mutable std::atomic refs; - MemoryManager* memory_manager; - - void Ref() const { - const auto cnt = refs.fetch_add(1, std::memory_order_relaxed); - ABSL_ASSERT(cnt >= 1); - } - - bool Unref() const { - const auto cnt = refs.fetch_sub(1, std::memory_order_acq_rel); - ABSL_ASSERT(cnt >= 1); - return cnt == 1; - } -}; - uintptr_t AlignUp(uintptr_t size, size_t align) { ABSL_ASSERT(size != 0); ABSL_ASSERT(absl::has_single_bit(align)); // Assert aligned to power of 2. @@ -122,94 +71,6 @@ T* AlignUp(T* pointer, size_t align) { AlignUp(reinterpret_cast(pointer), align)); } -inline constexpr size_t kControlBlockSize = sizeof(ControlBlock); -inline constexpr size_t kControlBlockAlign = alignof(ControlBlock); - -// When not using arena-based allocation, MemoryManager needs to embed a pointer -// to itself in the allocation block so the same memory manager can be used to -// deallocate. When the alignment requested is less than or equal to that of the -// native pointer alignment it is embedded at the beginning of the allocated -// block, otherwise its at the end. -// -// For allocations requiring alignment greater than alignof(ControlBlock) we -// cannot place the control block in front as it would change the alignment of -// T, resulting in undefined behavior. For allocations requiring less alignment -// than alignof(ControlBlock), we should not place the control back in back as -// it would waste memory due to having to pad the allocation to ensure -// ControlBlock itself is aligned. -enum class Placement { - kBefore = 0, - kAfter, -}; - -constexpr Placement GetPlacement(size_t align) { - return ABSL_PREDICT_TRUE(align <= kControlBlockAlign) ? Placement::kBefore - : Placement::kAfter; -} - -void* AdjustAfterAllocation(MemoryManager* memory_manager, void* pointer, - size_t size, size_t align) { - switch (GetPlacement(align)) { - case Placement::kBefore: - // Store the pointer to the memory manager at the beginning of the - // allocated block and adjust the pointer to immediately after it. - ::new (pointer) ControlBlock(memory_manager); - pointer = static_cast(static_cast(pointer) + - kControlBlockSize); - break; - case Placement::kAfter: - // Store the pointer to the memory manager at the end of the allocated - // block. Don't need to adjust the pointer. - ::new (static_cast(static_cast(pointer) + size - - kControlBlockSize)) - ControlBlock(memory_manager); - break; - } - return pointer; -} - -void* AdjustForDeallocation(void* pointer, size_t align) { - switch (GetPlacement(align)) { - case Placement::kBefore: - // We need to back up kPointerSize as that is actually the original - // allocated address returned from `Allocate`. - pointer = static_cast(static_cast(pointer) - - kControlBlockSize); - break; - case Placement::kAfter: - // No need to do anything. - break; - } - return pointer; -} - -ControlBlock* GetControlBlock(const void* pointer, size_t size, size_t align) { - ControlBlock* control_block; - switch (GetPlacement(align)) { - case Placement::kBefore: - // Embedded reference count block is located just before `pointer`. - control_block = reinterpret_cast( - static_cast(const_cast(pointer)) - - kControlBlockSize); - break; - case Placement::kAfter: - // Embedded reference count block is located at `pointer + size - - // kControlBlockSize`. - control_block = reinterpret_cast( - static_cast(const_cast(pointer)) + size - - kControlBlockSize); - break; - } - return control_block; -} - -size_t AdjustAllocationSize(size_t size, size_t align) { - if (GetPlacement(align) == Placement::kAfter) { - size = AlignUp(size, kControlBlockAlign); - } - return size + kControlBlockSize; -} - struct ArenaBlock final { // The base pointer of the virtual memory, always points to the start of a // page. @@ -303,12 +164,12 @@ class DefaultArenaMemoryManager final : public ArenaMemoryManager { } private: - AllocationResult Allocate(size_t size, size_t align) override { + void* Allocate(size_t size, size_t align) override { auto page_size = base_internal::GetPageSize(); if (align > page_size) { // Just, no. We refuse anything that requests alignment over the system // page size. - return AllocationResult{nullptr}; + return nullptr; } absl::MutexLock lock(&mutex_); bool bridge_gap = false; @@ -319,7 +180,7 @@ class DefaultArenaMemoryManager final : public ArenaMemoryManager { // large enough. auto maybe_block = ArenaBlockAllocate(AlignUp(size, page_size)); if (!maybe_block.has_value()) { - return AllocationResult{nullptr}; + return nullptr; } blocks_.push_back(std::move(maybe_block).value()); } else { @@ -330,7 +191,7 @@ class DefaultArenaMemoryManager final : public ArenaMemoryManager { auto maybe_block = ArenaBlockAllocate(AlignUp(size, page_size), last_block.end); if (!maybe_block.has_value()) { - return AllocationResult{nullptr}; + return nullptr; } bridge_gap = last_block.end == maybe_block.value().begin; blocks_.push_back(std::move(maybe_block).value()); @@ -345,9 +206,9 @@ class DefaultArenaMemoryManager final : public ArenaMemoryManager { size_t remaining = second_last_block.remaining(); void* pointer = second_last_block.Allocate(remaining); blocks_.back().Allocate(size - remaining); - return AllocationResult{pointer}; + return pointer; } - return AllocationResult{blocks_.back().Allocate(size)}; + return blocks_.back().Allocate(size); } void OwnDestructor(void* pointer, void (*destruct)(void*)) override { @@ -364,6 +225,27 @@ class DefaultArenaMemoryManager final : public ArenaMemoryManager { } // namespace +class GlobalMemoryManager final : public MemoryManager { + public: + GlobalMemoryManager() : MemoryManager(false) {} + + private: + // Never actually called by `MemoryManager`. + void* Allocate(size_t size, size_t align) override { + static_cast(size); + static_cast(align); + ABSL_INTERNAL_UNREACHABLE; + return nullptr; + } + + // Never actually called by `MemoryManager`. + void OwnDestructor(void* pointer, void (*destructor)(void*)) override { + static_cast(pointer); + static_cast(destructor); + ABSL_INTERNAL_UNREACHABLE; + } +}; + namespace base_internal { // Returns the platforms page size. When requesting vitual memory from the @@ -397,87 +279,6 @@ MemoryManager& MemoryManager::Global() { return *instance; } -void* MemoryManager::AllocateInternal(size_t& size, size_t& align) { - ABSL_ASSERT(size != 0); - ABSL_ASSERT(absl::has_single_bit(align)); // Assert aligned to power of 2. - size_t adjusted_size = size; - if (!allocation_only_) { - adjusted_size = AdjustAllocationSize(adjusted_size, align); - } - auto [pointer] = Allocate(adjusted_size, align); - if (ABSL_PREDICT_TRUE(pointer != nullptr) && !allocation_only_) { - pointer = AdjustAfterAllocation(this, pointer, adjusted_size, align); - } else { - // 0 is not a valid result of sizeof. So we use that to signal to the - // deleter that it should not perform a deletion and that the memory manager - // will. - size = align = 0; - } - return pointer; -} - -void MemoryManager::DeallocateInternal(void* pointer, size_t size, - size_t align) { - ABSL_ASSERT(pointer != nullptr); - ABSL_ASSERT(size != 0); - ABSL_ASSERT(absl::has_single_bit(align)); // Assert aligned to power of 2. - // `size` is the unadjusted size, the original sizeof(T) used during - // allocation. We need to adjust it to match the allocation size. - size = AdjustAllocationSize(size, align); - ControlBlock* control_block = GetControlBlock(pointer, size, align); - MemoryManager* memory_manager = control_block->memory_manager; - if constexpr (!std::is_trivially_destructible_v) { - control_block->~ControlBlock(); - } - pointer = AdjustForDeallocation(pointer, align); - memory_manager->Deallocate(pointer, size, align); -} - -MemoryManager& MemoryManager::Get(const void* pointer, size_t size, - size_t align) { - return *GetControlBlock(pointer, size, align)->memory_manager; -} - -void MemoryManager::Ref(const void* pointer, size_t size, size_t align) { - if (pointer != nullptr && size != 0) { - ABSL_ASSERT(absl::has_single_bit(align)); // Assert aligned to power of 2. - // `size` is the unadjusted size, the original sizeof(T) used during - // allocation. We need to adjust it to match the allocation size. - size = AdjustAllocationSize(size, align); - GetControlBlock(pointer, size, align)->Ref(); - } -} - -bool MemoryManager::UnrefInternal(const void* pointer, size_t size, - size_t align) { - bool cleanup = false; - if (pointer != nullptr && size != 0) { - ABSL_ASSERT(absl::has_single_bit(align)); // Assert aligned to power of 2. - // `size` is the unadjusted size, the original sizeof(T) used during - // allocation. We need to adjust it to match the allocation size. - size = AdjustAllocationSize(size, align); - cleanup = GetControlBlock(pointer, size, align)->Unref(); - } - return cleanup; -} - -void MemoryManager::OwnDestructor(void* pointer, void (*destruct)(void*)) { - static_cast(pointer); - static_cast(destruct); - // OwnDestructor is only called for arena-based memory managers by `New`. If - // we got here, something is seriously wrong so crashing is okay. - std::abort(); -} - -void ArenaMemoryManager::Deallocate(void* pointer, size_t size, size_t align) { - static_cast(pointer); - static_cast(size); - static_cast(align); - // Most arena-based allocators will not deallocate individual allocations, so - // we default the implementation to std::abort(). - std::abort(); -} - std::unique_ptr ArenaMemoryManager::Default() { return std::make_unique(); } diff --git a/base/memory_manager.h b/base/memory_manager.h index e333fe18b..aa7b56841 100644 --- a/base/memory_manager.h +++ b/base/memory_manager.h @@ -21,134 +21,16 @@ #include #include "absl/base/attributes.h" -#include "absl/base/macros.h" -#include "absl/base/optimization.h" -#include "base/internal/memory_manager.pre.h" // IWYU pragma: export +#include "base/internal/data.h" +#include "base/internal/memory_manager.h" +#include "base/managed_memory.h" namespace cel { class MemoryManager; +class GlobalMemoryManager; class ArenaMemoryManager; -// `ManagedMemory` is a smart pointer which ensures any applicable object -// destructors and deallocation are eventually performed. Copying does not -// actually copy the underlying T, instead a pointer is copied and optionally -// reference counted. Moving does not actually move the underlying T, instead a -// pointer is moved. -// -// TODO(issues/5): consider feature parity with std::unique_ptr -template -class ManagedMemory final { - public: - ManagedMemory() = default; - - ManagedMemory(const ManagedMemory& other) - : ptr_(other.ptr_), size_(other.size_), align_(other.align_) { - Ref(); - } - - ManagedMemory(ManagedMemory&& other) - : ptr_(other.ptr_), size_(other.size_), align_(other.align_) { - other.ptr_ = nullptr; - other.size_ = other.align_ = 0; - } - - ~ManagedMemory() { Unref(); } - - ManagedMemory& operator=(const ManagedMemory& other) { - if (ABSL_PREDICT_TRUE(this != std::addressof(other))) { - other.Ref(); - Unref(); - ptr_ = other.ptr_; - size_ = other.size_; - align_ = other.align_; - } - return *this; - } - - ManagedMemory& operator=(ManagedMemory&& other) { - if (ABSL_PREDICT_TRUE(this != std::addressof(other))) { - reset(); - swap(other); - } - return *this; - } - - T* release() { - ABSL_ASSERT(size_ == 0); - return base_internal::ManagedMemoryRelease(*this); - } - - void reset() { - Unref(); - ptr_ = nullptr; - size_ = align_ = 0; - } - - void swap(ManagedMemory& other) { - std::swap(ptr_, other.ptr_); - std::swap(size_, other.size_); - std::swap(align_, other.align_); - } - - constexpr T* get() const { return ptr_; } - - constexpr T& operator*() const ABSL_ATTRIBUTE_LIFETIME_BOUND { - ABSL_ASSERT(static_cast(*this)); - return *get(); - } - - constexpr T* operator->() const { - ABSL_ASSERT(static_cast(*this)); - return get(); - } - - constexpr explicit operator bool() const { return get() != nullptr; } - - private: - friend class MemoryManager; - template - friend constexpr size_t base_internal::GetManagedMemorySize( - const ManagedMemory& managed_memory); - template - friend constexpr size_t base_internal::GetManagedMemoryAlignment( - const ManagedMemory& managed_memory); - template - friend constexpr F* base_internal::ManagedMemoryRelease( - ManagedMemory& managed_memory); - - constexpr ManagedMemory(T* ptr, size_t size, size_t align) - : ptr_(ptr), size_(size), align_(align) {} - - void Ref() const; - - void Unref() const; - - T* ptr_ = nullptr; - size_t size_ = 0; - size_t align_ = 0; -}; - -template -constexpr bool operator==(const ManagedMemory& lhs, std::nullptr_t) { - return !static_cast(lhs); -} - -template -constexpr bool operator==(std::nullptr_t, const ManagedMemory& rhs) { - return !static_cast(rhs); -} - -template -constexpr bool operator!=(const ManagedMemory& lhs, std::nullptr_t) { - return !operator==(lhs, nullptr); -} - -template -constexpr bool operator!=(std::nullptr_t, const ManagedMemory& rhs) { - return !operator==(nullptr, rhs); -} - // `MemoryManager` is an abstraction over memory management that supports // different allocation strategies. class MemoryManager { @@ -160,125 +42,72 @@ class MemoryManager { // Allocates and constructs `T`. In the event of an allocation failure nullptr // is returned. template - ManagedMemory New(Args&&... args) - ABSL_ATTRIBUTE_LIFETIME_BOUND ABSL_MUST_USE_RESULT { - size_t size = sizeof(T); - size_t align = alignof(T); - void* pointer = AllocateInternal(size, align); - if (ABSL_PREDICT_TRUE(pointer != nullptr)) { - ::new (pointer) T(std::forward(args)...); + std::enable_if_t, ManagedMemory> + New(Args&&... args) ABSL_ATTRIBUTE_LIFETIME_BOUND ABSL_MUST_USE_RESULT { + static_assert(std::is_base_of_v, + "T must only be stored inline"); + if (!allocation_only_) { + T* pointer = new T(std::forward(args)...); + base_internal::Metadata::For(pointer)->SetReferenceCounted(); + return ManagedMemory(pointer); + } + void* pointer = Allocate(sizeof(T), alignof(T)); + ::new (pointer) T(std::forward(args)...); + if constexpr (!std::is_trivially_destructible_v) { + OwnDestructor(pointer, + &base_internal::MemoryManagerDestructor::Destruct); + } + base_internal::Metadata::For(reinterpret_cast(pointer)) + ->SetArenaAllocated(); + return ManagedMemory(reinterpret_cast(pointer)); + } + + // Allocates and constructs `T`. In the event of an allocation failure nullptr + // is returned. + template + std::enable_if_t>, + ManagedMemory> + New(Args&&... args) ABSL_ATTRIBUTE_LIFETIME_BOUND ABSL_MUST_USE_RESULT { + if (!allocation_only_) { + base_internal::ManagedMemoryDestructor destructor = nullptr; if constexpr (!std::is_trivially_destructible_v) { - if (allocation_only_) { - OwnDestructor(pointer, - &base_internal::MemoryManagerDestructor::Destruct); - } + destructor = &base_internal::MemoryManagerDestructor::Destruct; } + auto [state, pointer] = base_internal::ManagedMemoryState::New( + sizeof(T), alignof(T), destructor); + ::new (pointer) T(std::forward(args)...); + return ManagedMemory(reinterpret_cast(pointer), state); + } + void* pointer = Allocate(sizeof(T), alignof(T)); + ::new (pointer) T(std::forward(args)...); + if constexpr (!std::is_trivially_destructible_v) { + OwnDestructor(pointer, + &base_internal::MemoryManagerDestructor::Destruct); } - return ManagedMemory(reinterpret_cast(pointer), size, align); + return ManagedMemory(reinterpret_cast(pointer), nullptr); } - protected: - MemoryManager() : MemoryManager(false) {} - - template - struct AllocationResult final { - Pointer pointer = nullptr; - }; - private: - template - friend class ManagedMemory; + friend class GlobalMemoryManager; friend class ArenaMemoryManager; - friend class base_internal::Resource; - friend MemoryManager& base_internal::GetMemoryManager(const void* pointer, - size_t size, - size_t align); - // Only for use by ArenaMemoryManager. + // Only for use by GlobalMemoryManager and ArenaMemoryManager. explicit MemoryManager(bool allocation_only) : allocation_only_(allocation_only) {} - static MemoryManager& Get(const void* pointer, size_t size, size_t align); - - void* AllocateInternal(size_t& size, size_t& align); - - static void DeallocateInternal(void* pointer, size_t size, size_t align); - - // Potentially increment the reference count in the control block for the - // previously allocated memory from `New()`. This is intended to be called - // from `ManagedMemory`. - // - // If size is 0, then the allocation was arena-based. - static void Ref(const void* pointer, size_t size, size_t align); - - // Potentially decrement the reference count in the control block for the - // previously allocated memory from `New()`. Returns true if `Delete()` should - // be called. - // - // If size is 0, then the allocation was arena-based and this call is a noop. - static bool UnrefInternal(const void* pointer, size_t size, size_t align); - - // Delete a previous `New()` result when `allocation_only_` is false. - template - static void Delete(T* pointer, size_t size, size_t align) { - if constexpr (!std::is_trivially_destructible_v) { - pointer->~T(); - } - DeallocateInternal( - static_cast(const_cast*>(pointer)), size, - align); - } - - // Potentially decrement the reference count in the control block and - // deallocate the memory for the previously allocated memory from `New()`. - // This is intended to be called from `ManagedMemory`. - // - // If size is 0, then the allocation was arena-based and this call is a noop. - template - static void Unref(T* pointer, size_t size, size_t align) { - if (UnrefInternal(pointer, size, align)) { - Delete(pointer, size, align); - } - } - - // These are virtual private, ensuring only `MemoryManager` calls these. Which - // methods need to be implemented and which are called depends on whether the - // implementation is using arena memory management or not. - // - // If the implementation is using arenas then `Deallocate()` will never be - // called, `OwnDestructor` must be implemented, and `AllocationOnly` must - // return true. If the implementation is *not* using arenas then `Deallocate` - // must be implemented, `OwnDestructor` will never be called, and - // `AllocationOnly` will return false. + // These are virtual private, ensuring only `MemoryManager` calls these. // Allocates memory of at least size `size` in bytes that is at least as // aligned as `align`. - virtual AllocationResult Allocate(size_t size, size_t align) = 0; - - // Deallocate the given pointer previously allocated via `Allocate`, assuming - // `AllocationResult::owned` was true. Calling this when - // `AllocationResult::owned` was false is undefined behavior. - virtual void Deallocate(void* pointer, size_t size, size_t align) = 0; + virtual void* Allocate(size_t size, size_t align) = 0; // Registers a destructor to be run upon destruction of the memory management // implementation. - // - // This method is only valid for arena memory managers. - virtual void OwnDestructor(void* pointer, void (*destruct)(void*)); + virtual void OwnDestructor(void* pointer, void (*destruct)(void*)) = 0; const bool allocation_only_; }; -template -void ManagedMemory::Ref() const { - MemoryManager::Ref(ptr_, size_, align_); -} - -template -void ManagedMemory::Unref() const { - MemoryManager::Unref(ptr_, size_, align_); -} - namespace extensions { class ProtoMemoryManager; } @@ -301,17 +130,8 @@ class ArenaMemoryManager : public MemoryManager { // other derivations of ArenaMemoryManager should be allocation-only. explicit ArenaMemoryManager(bool allocation_only) : MemoryManager(allocation_only) {} - - // Default implementation calls std::abort(). If you have a special case where - // you support deallocating individual allocations, override this. - void Deallocate(void* pointer, size_t size, size_t align) override; - - // OwnDestructor is typically required for arena-based memory managers. - void OwnDestructor(void* pointer, void (*destruct)(void*)) override = 0; }; } // namespace cel -#include "base/internal/memory_manager.post.h" // IWYU pragma: export - #endif // THIRD_PARTY_CEL_CPP_BASE_MEMORY_MANAGER_H_ diff --git a/base/type.cc b/base/type.cc index 4d1bb6a6f..cecdfa73d 100644 --- a/base/type.cc +++ b/base/type.cc @@ -17,25 +17,295 @@ #include #include -#include "absl/hash/hash.h" +#include "absl/base/macros.h" +#include "absl/base/optimization.h" +#include "base/handle.h" +#include "base/internal/data.h" +#include "base/types/any_type.h" +#include "base/types/bool_type.h" +#include "base/types/bytes_type.h" +#include "base/types/double_type.h" +#include "base/types/duration_type.h" +#include "base/types/dyn_type.h" +#include "base/types/enum_type.h" +#include "base/types/error_type.h" +#include "base/types/int_type.h" +#include "base/types/list_type.h" +#include "base/types/map_type.h" +#include "base/types/null_type.h" +#include "base/types/string_type.h" +#include "base/types/struct_type.h" +#include "base/types/timestamp_type.h" +#include "base/types/type_type.h" +#include "base/types/uint_type.h" namespace cel { CEL_INTERNAL_TYPE_IMPL(Type); -std::string Type::DebugString() const { return std::string(name()); } +absl::string_view Type::name() const { + switch (kind()) { + case Kind::kNullType: + return static_cast(this)->name(); + case Kind::kError: + return static_cast(this)->name(); + case Kind::kDyn: + return static_cast(this)->name(); + case Kind::kAny: + return static_cast(this)->name(); + case Kind::kType: + return static_cast(this)->name(); + case Kind::kBool: + return static_cast(this)->name(); + case Kind::kInt: + return static_cast(this)->name(); + case Kind::kUint: + return static_cast(this)->name(); + case Kind::kDouble: + return static_cast(this)->name(); + case Kind::kString: + return static_cast(this)->name(); + case Kind::kBytes: + return static_cast(this)->name(); + case Kind::kEnum: + return static_cast(this)->name(); + case Kind::kDuration: + return static_cast(this)->name(); + case Kind::kTimestamp: + return static_cast(this)->name(); + case Kind::kList: + return static_cast(this)->name(); + case Kind::kMap: + return static_cast(this)->name(); + case Kind::kStruct: + return static_cast(this)->name(); + } +} -std::pair Type::SizeAndAlignment() const { - // Currently no implementation of Type is reference counted. However once we - // introduce Struct it likely will be. Using 0 here will trigger runtime - // asserts in case of undefined behavior. Struct should force this to be pure. - return std::pair(0, 0); +std::string Type::DebugString() const { + switch (kind()) { + case Kind::kNullType: + return static_cast(this)->DebugString(); + case Kind::kError: + return static_cast(this)->DebugString(); + case Kind::kDyn: + return static_cast(this)->DebugString(); + case Kind::kAny: + return static_cast(this)->DebugString(); + case Kind::kType: + return static_cast(this)->DebugString(); + case Kind::kBool: + return static_cast(this)->DebugString(); + case Kind::kInt: + return static_cast(this)->DebugString(); + case Kind::kUint: + return static_cast(this)->DebugString(); + case Kind::kDouble: + return static_cast(this)->DebugString(); + case Kind::kString: + return static_cast(this)->DebugString(); + case Kind::kBytes: + return static_cast(this)->DebugString(); + case Kind::kEnum: + return static_cast(this)->DebugString(); + case Kind::kDuration: + return static_cast(this)->DebugString(); + case Kind::kTimestamp: + return static_cast(this)->DebugString(); + case Kind::kList: + return static_cast(this)->DebugString(); + case Kind::kMap: + return static_cast(this)->DebugString(); + case Kind::kStruct: + return static_cast(this)->DebugString(); + } } -bool Type::Equals(const Type& other) const { return kind() == other.kind(); } +bool Type::Equals(const Type& other) const { + if (this == &other) { + return true; + } + switch (kind()) { + case Kind::kNullType: + return static_cast(this)->Equals(other); + case Kind::kError: + return static_cast(this)->Equals(other); + case Kind::kDyn: + return static_cast(this)->Equals(other); + case Kind::kAny: + return static_cast(this)->Equals(other); + case Kind::kType: + return static_cast(this)->Equals(other); + case Kind::kBool: + return static_cast(this)->Equals(other); + case Kind::kInt: + return static_cast(this)->Equals(other); + case Kind::kUint: + return static_cast(this)->Equals(other); + case Kind::kDouble: + return static_cast(this)->Equals(other); + case Kind::kString: + return static_cast(this)->Equals(other); + case Kind::kBytes: + return static_cast(this)->Equals(other); + case Kind::kEnum: + return static_cast(this)->Equals(other); + case Kind::kDuration: + return static_cast(this)->Equals(other); + case Kind::kTimestamp: + return static_cast(this)->Equals(other); + case Kind::kList: + return static_cast(this)->Equals(other); + case Kind::kMap: + return static_cast(this)->Equals(other); + case Kind::kStruct: + return static_cast(this)->Equals(other); + } +} void Type::HashValue(absl::HashState state) const { - absl::HashState::combine(std::move(state), kind(), name()); + switch (kind()) { + case Kind::kNullType: + return static_cast(this)->HashValue(std::move(state)); + case Kind::kError: + return static_cast(this)->HashValue(std::move(state)); + case Kind::kDyn: + return static_cast(this)->HashValue(std::move(state)); + case Kind::kAny: + return static_cast(this)->HashValue(std::move(state)); + case Kind::kType: + return static_cast(this)->HashValue(std::move(state)); + case Kind::kBool: + return static_cast(this)->HashValue(std::move(state)); + case Kind::kInt: + return static_cast(this)->HashValue(std::move(state)); + case Kind::kUint: + return static_cast(this)->HashValue(std::move(state)); + case Kind::kDouble: + return static_cast(this)->HashValue(std::move(state)); + case Kind::kString: + return static_cast(this)->HashValue(std::move(state)); + case Kind::kBytes: + return static_cast(this)->HashValue(std::move(state)); + case Kind::kEnum: + return static_cast(this)->HashValue(std::move(state)); + case Kind::kDuration: + return static_cast(this)->HashValue( + std::move(state)); + case Kind::kTimestamp: + return static_cast(this)->HashValue( + std::move(state)); + case Kind::kList: + return static_cast(this)->HashValue(std::move(state)); + case Kind::kMap: + return static_cast(this)->HashValue(std::move(state)); + case Kind::kStruct: + return static_cast(this)->HashValue(std::move(state)); + } +} + +namespace base_internal { + +bool PersistentTypeHandle::Equals(const PersistentTypeHandle& other) const { + const auto* self = static_cast(data_.get()); + const auto* that = static_cast(other.data_.get()); + if (self == that) { + return true; + } + if (self == nullptr || that == nullptr) { + return false; + } + return self->Equals(*that); +} + +void PersistentTypeHandle::HashValue(absl::HashState state) const { + if (const auto* pointer = static_cast(data_.get()); + ABSL_PREDICT_TRUE(pointer != nullptr)) { + pointer->HashValue(std::move(state)); + } +} + +void PersistentTypeHandle::CopyFrom(const PersistentTypeHandle& other) { + // data_ is currently uninitialized. + auto locality = other.data_.locality(); + if (ABSL_PREDICT_FALSE(locality == DataLocality::kStoredInline && + !other.data_.IsTriviallyCopyable())) { + // Type currently has only trivially copyable inline + // representations. + ABSL_INTERNAL_UNREACHABLE; + } else { + // We can simply just copy the bytes. + data_.CopyFrom(other.data_); + if (locality == DataLocality::kReferenceCounted) { + Ref(); + } + } +} + +void PersistentTypeHandle::MoveFrom(PersistentTypeHandle& other) { + // data_ is currently uninitialized. + auto locality = other.data_.locality(); + if (ABSL_PREDICT_FALSE(locality == DataLocality::kStoredInline && + !other.data_.IsTriviallyCopyable())) { + // Type currently has only trivially copyable inline + // representations. + ABSL_INTERNAL_UNREACHABLE; + } else { + // We can simply just copy the bytes. + data_.MoveFrom(other.data_); + } +} + +void PersistentTypeHandle::CopyAssign(const PersistentTypeHandle& other) { + // data_ is initialized. + Destruct(); + CopyFrom(other); +} + +void PersistentTypeHandle::MoveAssign(PersistentTypeHandle& other) { + // data_ is initialized. + Destruct(); + MoveFrom(other); } +void PersistentTypeHandle::Destruct() { + switch (data_.locality()) { + case DataLocality::kNull: + break; + case DataLocality::kStoredInline: + if (ABSL_PREDICT_FALSE(!data_.IsTriviallyDestructible())) { + // Type currently has only trivially destructible inline + // representations. + ABSL_INTERNAL_UNREACHABLE; + } + break; + case DataLocality::kReferenceCounted: + Unref(); + break; + case DataLocality::kArenaAllocated: + break; + } +} + +void PersistentTypeHandle::Delete() const { + switch (data_.kind()) { + case Kind::kList: + delete static_cast(static_cast(data_.get())); + break; + case Kind::kMap: + delete static_cast(static_cast(data_.get())); + break; + case Kind::kEnum: + delete static_cast(static_cast(data_.get())); + break; + case Kind::kStruct: + delete static_cast(static_cast(data_.get())); + break; + default: + ABSL_INTERNAL_UNREACHABLE; + } +} + +} // namespace base_internal + } // namespace cel diff --git a/base/type.h b/base/type.h index f6d7d4f3c..f3a9cb62c 100644 --- a/base/type.h +++ b/base/type.h @@ -15,130 +15,320 @@ #ifndef THIRD_PARTY_CEL_CPP_BASE_TYPE_H_ #define THIRD_PARTY_CEL_CPP_BASE_TYPE_H_ +#include #include +#include #include +#include "absl/base/attributes.h" #include "absl/hash/hash.h" #include "absl/strings/string_view.h" #include "absl/types/variant.h" +#include "absl/utility/utility.h" #include "base/handle.h" -#include "base/internal/type.pre.h" // IWYU pragma: export +#include "base/internal/data.h" +#include "base/internal/type.h" // IWYU pragma: export #include "base/kind.h" -#include "base/memory_manager.h" +#include "internal/casts.h" // IWYU pragma: keep namespace cel { -class Type; -class NullType; -class ErrorType; -class DynType; -class AnyType; -class BoolType; -class IntType; -class UintType; -class DoubleType; -class StringType; -class BytesType; -class DurationType; -class TimestampType; class EnumType; +class StructType; class ListType; class MapType; -class TypeType; class TypeFactory; class TypeProvider; class TypeManager; -class NullValue; -class ErrorValue; -class BoolValue; -class IntValue; -class UintValue; -class DoubleValue; -class BytesValue; -class StringValue; -class DurationValue; -class TimestampValue; -class EnumValue; -class StructValue; -class TypeValue; class ValueFactory; class TypedEnumValueFactory; class TypedStructValueFactory; -namespace internal { -template -class NoDestructor; -} - // A representation of a CEL type that enables reflection, for static analysis, // and introspection, for program construction, of types. -class Type : public base_internal::Resource { +class Type : public base_internal::Data { public: + static bool Is(const Type& type ABSL_ATTRIBUTE_UNUSED) { return true; } + // Returns the type kind. - virtual Kind kind() const = 0; + Kind kind() const { return base_internal::Metadata::For(this)->kind(); } // Returns the type name, i.e. "list". - virtual absl::string_view name() const = 0; + absl::string_view name() const; - virtual std::string DebugString() const; + std::string DebugString() const; - // Called by base_internal::TypeHandleBase. - // Note GCC does not consider a friend member as a member of a friend. - virtual bool Equals(const Type& other) const; + void HashValue(absl::HashState state) const; - // Called by base_internal::TypeHandleBase. - // Note GCC does not consider a friend member as a member of a friend. - virtual void HashValue(absl::HashState state) const; + bool Equals(const Type& other) const; private: - friend class NullType; - friend class ErrorType; - friend class DynType; - friend class AnyType; - friend class BoolType; - friend class IntType; - friend class UintType; - friend class DoubleType; - friend class StringType; - friend class BytesType; - friend class DurationType; - friend class TimestampType; friend class EnumType; friend class StructType; friend class ListType; friend class MapType; - friend class TypeType; - friend class base_internal::TypeHandleBase; + template + friend class base_internal::SimpleType; Type() = default; Type(const Type&) = default; Type(Type&&) = default; + Type& operator=(const Type&) = default; + Type& operator=(Type&&) = default; +}; - // Called by base_internal::TypeHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Type& type) { return true; } +template +H AbslHashValue(H state, const Type& type) { + type.HashValue(absl::HashState::Create(&state)); + return state; +} - // For non-inlined types that are reference counted, this is the result of - // `sizeof` and `alignof` for the most derived class. - std::pair SizeAndAlignment() const override; +inline bool operator==(const Type& lhs, const Type& rhs) { + return lhs.Equals(rhs); +} - using base_internal::Resource::Ref; - using base_internal::Resource::Unref; -}; +inline bool operator!=(const Type& lhs, const Type& rhs) { + return !operator==(lhs, rhs); +} } // namespace cel -// type.pre.h forward declares types so they can be friended above. The types -// themselves need to be defined after everything else as they need to access or -// derive from the above types. We do this in type.post.h to avoid mudying this -// header and making it difficult to read. -#include "base/internal/type.post.h" // IWYU pragma: export +// ----------------------------------------------------------------------------- +// Internal implementation details. namespace cel { +namespace base_internal { + +class PersistentTypeHandle final { + public: + PersistentTypeHandle() = default; + + template + explicit PersistentTypeHandle(absl::in_place_type_t in_place_type, + Args&&... args) { + data_.ConstructInline(std::forward(args)...); + } + + explicit PersistentTypeHandle(const Type& type) { data_.ConstructHeap(type); } + + PersistentTypeHandle(const PersistentTypeHandle& other) { CopyFrom(other); } + + PersistentTypeHandle(PersistentTypeHandle&& other) { MoveFrom(other); } + + ~PersistentTypeHandle() { Destruct(); } + + PersistentTypeHandle& operator=(const PersistentTypeHandle& other) { + if (this != &other) { + CopyAssign(other); + } + return *this; + } + + PersistentTypeHandle& operator=(PersistentTypeHandle&& other) { + if (this != &other) { + MoveAssign(other); + } + return *this; + } + + Type* get() const { return static_cast(data_.get()); } + + explicit operator bool() const { return !data_.IsNull(); } + + bool Equals(const PersistentTypeHandle& other) const; + + void HashValue(absl::HashState state) const; + + private: + void CopyFrom(const PersistentTypeHandle& other); + + void MoveFrom(PersistentTypeHandle& other); + + void CopyAssign(const PersistentTypeHandle& other); + + void MoveAssign(PersistentTypeHandle& other); + + void Ref() const { data_.Ref(); } + + void Unref() const { + if (data_.Unref()) { + Delete(); + } + } + + void Destruct(); + + void Delete() const; + + AnyType data_; +}; + +template +H AbslHashValue(H state, const PersistentTypeHandle& handle) { + handle.HashValue(absl::HashState::Create(&state)); + return state; +} + +inline bool operator==(const PersistentTypeHandle& lhs, + const PersistentTypeHandle& rhs) { + return lhs.Equals(rhs); +} + +inline bool operator!=(const PersistentTypeHandle& lhs, + const PersistentTypeHandle& rhs) { + return !operator==(lhs, rhs); +} + +// Specialization for Type providing the implementation to `Persistent`. +template <> +struct HandleTraits { + using handle_type = PersistentTypeHandle; +}; + +// Partial specialization for `Persistent` for all classes derived from Type. +template +struct HandleTraits< + HandleType::kPersistent, T, + std::enable_if_t<(std::is_base_of_v && !std::is_same_v)>> + final : public HandleTraits {}; + +template +struct SimpleTypeName; + +template <> +struct SimpleTypeName { + static constexpr absl::string_view value = "null_type"; +}; + +template <> +struct SimpleTypeName { + static constexpr absl::string_view value = "*error*"; +}; + +template <> +struct SimpleTypeName { + static constexpr absl::string_view value = "dyn"; +}; + +template <> +struct SimpleTypeName { + static constexpr absl::string_view value = "google.protobuf.Any"; +}; + +template <> +struct SimpleTypeName { + static constexpr absl::string_view value = "bool"; +}; + +template <> +struct SimpleTypeName { + static constexpr absl::string_view value = "int"; +}; + +template <> +struct SimpleTypeName { + static constexpr absl::string_view value = "uint"; +}; + +template <> +struct SimpleTypeName { + static constexpr absl::string_view value = "double"; +}; + +template <> +struct SimpleTypeName { + static constexpr absl::string_view value = "bytes"; +}; + +template <> +struct SimpleTypeName { + static constexpr absl::string_view value = "string"; +}; + +template <> +struct SimpleTypeName { + static constexpr absl::string_view value = "google.protobuf.Duration"; +}; + +template <> +struct SimpleTypeName { + static constexpr absl::string_view value = "google.protobuf.Timestamp"; +}; + +template <> +struct SimpleTypeName { + static constexpr absl::string_view value = "type"; +}; + +template +class SimpleType : public Type, public InlineData { + public: + static constexpr Kind kKind = K; + static constexpr absl::string_view kName = SimpleTypeName::value; + + static bool Is(const Type& type) { return type.kind() == kKind; } + + SimpleType() = default; + SimpleType(const SimpleType&) = default; + SimpleType(SimpleType&&) = default; + SimpleType& operator=(const SimpleType&) = default; + SimpleType& operator=(SimpleType&&) = default; + + constexpr Kind kind() const { return kKind; } + + constexpr absl::string_view name() const { return kName; } + + std::string DebugString() const { return std::string(name()); } + + void HashValue(absl::HashState state) const { + absl::HashState::combine(std::move(state), kind(), name()); + } + + bool Equals(const Type& other) const { return kind() == other.kind(); } + + private: + friend class PersistentTypeHandle; + + static constexpr uintptr_t kVirtualPointer = + kStoredInline | kTriviallyCopyable | kTriviallyDestructible | + (static_cast(kKind) << kKindShift); + + uintptr_t vptr_ ABSL_ATTRIBUTE_UNUSED = kVirtualPointer; +}; + +} // namespace base_internal + CEL_INTERNAL_TYPE_DECL(Type); } // namespace cel +#define CEL_INTERNAL_SIMPLE_TYPE_MEMBERS(type_class, value_class) \ + private: \ + friend class value_class; \ + friend class TypeFactory; \ + friend class base_internal::PersistentTypeHandle; \ + template \ + friend class base_internal::SimpleValue; \ + template \ + friend class base_internal::AnyData; \ + \ + ABSL_ATTRIBUTE_PURE_FUNCTION static const Persistent& \ + Get(); \ + \ + type_class() = default; \ + type_class(const type_class&) = default; \ + type_class(type_class&&) = default; \ + type_class& operator=(const type_class&) = default; \ + type_class& operator=(type_class&&) = default + +#define CEL_INTERNAL_SIMPLE_TYPE_STANDALONES(type_class) \ + static_assert(std::is_trivially_copyable_v, \ + #type_class " must be trivially copyable"); \ + static_assert(std::is_trivially_destructible_v, \ + #type_class " must be trivially destructible"); \ + \ + CEL_INTERNAL_TYPE_DECL(type_class) + #endif // THIRD_PARTY_CEL_CPP_BASE_TYPE_H_ diff --git a/base/type_factory.cc b/base/type_factory.cc index 5ad1ef789..2b257e517 100644 --- a/base/type_factory.cc +++ b/base/type_factory.cc @@ -29,93 +29,50 @@ using base_internal::PersistentHandleFactory; } // namespace -namespace base_internal { - -class ListTypeImpl final : public ListType { - public: - explicit ListTypeImpl(Persistent element) - : element_(std::move(element)) {} - - Persistent element() const override { return element_; } - - private: - std::pair SizeAndAlignment() const override { - return std::make_pair(sizeof(ListTypeImpl), alignof(ListTypeImpl)); - } - - Persistent element_; -}; - -class MapTypeImpl final : public MapType { - public: - MapTypeImpl(Persistent key, Persistent value) - : key_(std::move(key)), value_(std::move(value)) {} - - Persistent key() const override { return key_; } - - Persistent value() const override { return value_; } - - private: - std::pair SizeAndAlignment() const override { - return std::make_pair(sizeof(MapTypeImpl), alignof(MapTypeImpl)); - } - - Persistent key_; - Persistent value_; -}; - -} // namespace base_internal - Persistent TypeFactory::GetNullType() { - return WrapSingletonType(); + return NullType::Get(); } Persistent TypeFactory::GetErrorType() { - return WrapSingletonType(); + return ErrorType::Get(); } -Persistent TypeFactory::GetDynType() { - return WrapSingletonType(); -} +Persistent TypeFactory::GetDynType() { return DynType::Get(); } -Persistent TypeFactory::GetAnyType() { - return WrapSingletonType(); -} +Persistent TypeFactory::GetAnyType() { return AnyType::Get(); } Persistent TypeFactory::GetBoolType() { - return WrapSingletonType(); + return BoolType::Get(); } -Persistent TypeFactory::GetIntType() { - return WrapSingletonType(); -} +Persistent TypeFactory::GetIntType() { return IntType::Get(); } Persistent TypeFactory::GetUintType() { - return WrapSingletonType(); + return UintType::Get(); } Persistent TypeFactory::GetDoubleType() { - return WrapSingletonType(); + return DoubleType::Get(); } Persistent TypeFactory::GetStringType() { - return WrapSingletonType(); + return StringType::Get(); } Persistent TypeFactory::GetBytesType() { - return WrapSingletonType(); + return BytesType::Get(); } Persistent TypeFactory::GetDurationType() { - return WrapSingletonType(); + return DurationType::Get(); } Persistent TypeFactory::GetTimestampType() { - return WrapSingletonType(); + return TimestampType::Get(); } Persistent TypeFactory::GetTypeType() { - return WrapSingletonType(); + return TypeType::Get(); } absl::StatusOr> TypeFactory::CreateListType( @@ -125,8 +82,8 @@ absl::StatusOr> TypeFactory::CreateListType( if (existing != list_types_.end()) { return existing->second; } - auto list_type = PersistentHandleFactory::Make< - const base_internal::ListTypeImpl>(memory_manager(), element); + auto list_type = PersistentHandleFactory::Make( + memory_manager(), element); if (ABSL_PREDICT_FALSE(!list_type)) { // TODO(issues/5): maybe have the handle factories return statuses as // they can add details on the size and alignment more easily and @@ -145,8 +102,8 @@ absl::StatusOr> TypeFactory::CreateMapType( if (existing != map_types_.end()) { return existing->second; } - auto map_type = PersistentHandleFactory::Make< - const base_internal::MapTypeImpl>(memory_manager(), key, value); + auto map_type = PersistentHandleFactory::Make( + memory_manager(), key, value); if (ABSL_PREDICT_FALSE(!map_type)) { // TODO(issues/5): maybe have the handle factories return statuses as // they can add details on the size and alignment more easily and diff --git a/base/type_factory.h b/base/type_factory.h index 02cc1674a..b96f20b1f 100644 --- a/base/type_factory.h +++ b/base/type_factory.h @@ -115,16 +115,6 @@ class TypeFactory final { MemoryManager& memory_manager() const { return memory_manager_; } private: - template - static Persistent WrapSingletonType() { - // This is not normal, but we treat the underlying object as having been - // arena allocated. The only way to do this is through - // TransientHandleFactory. - return Persistent( - base_internal::PersistentHandleFactory::template MakeUnmanaged< - const T>(T::Get())); - } - MemoryManager& memory_manager_; absl::Mutex list_types_mutex_; diff --git a/base/types/any_type.cc b/base/types/any_type.cc index e1ba938da..4c7c51f4a 100644 --- a/base/types/any_type.cc +++ b/base/types/any_type.cc @@ -14,15 +14,28 @@ #include "base/types/any_type.h" -#include "internal/no_destructor.h" +#include "absl/base/attributes.h" +#include "absl/base/call_once.h" namespace cel { CEL_INTERNAL_TYPE_IMPL(AnyType); -const AnyType& AnyType::Get() { - static const internal::NoDestructor instance; - return *instance; +namespace { + +ABSL_CONST_INIT absl::once_flag instance_once; +alignas(Persistent) char instance_storage[sizeof( + Persistent)]; + +} // namespace + +const Persistent& AnyType::Get() { + absl::call_once(instance_once, []() { + base_internal::PersistentHandleFactory::MakeAt( + &instance_storage[0]); + }); + return *reinterpret_cast*>( + &instance_storage[0]); } } // namespace cel diff --git a/base/types/any_type.h b/base/types/any_type.h index 5ba77622e..748809ff9 100644 --- a/base/types/any_type.h +++ b/base/types/any_type.h @@ -15,41 +15,39 @@ #ifndef THIRD_PARTY_CEL_CPP_BASE_TYPES_ANY_TYPE_H_ #define THIRD_PARTY_CEL_CPP_BASE_TYPES_ANY_TYPE_H_ -#include -#include -#include - -#include "absl/strings/string_view.h" #include "base/kind.h" #include "base/type.h" namespace cel { -class AnyType final : public Type { +class AnyValue; + +class AnyType final : public base_internal::SimpleType { + private: + using Base = base_internal::SimpleType; + public: - Kind kind() const override { return Kind::kAny; } + using Base::kKind; - absl::string_view name() const override { return "google.protobuf.Any"; } + using Base::kName; - private: - friend class TypeFactory; - template - friend class internal::NoDestructor; - friend class base_internal::TypeHandleBase; + using Base::Is; + + using Base::kind; - // Called by base_internal::TypeHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Type& type) { return type.kind() == Kind::kAny; } + using Base::name; - ABSL_ATTRIBUTE_PURE_FUNCTION static const AnyType& Get(); + using Base::DebugString; - AnyType() = default; + using Base::HashValue; - AnyType(const AnyType&) = delete; - AnyType(AnyType&&) = delete; + using Base::Equals; + + private: + CEL_INTERNAL_SIMPLE_TYPE_MEMBERS(AnyType, AnyValue); }; -CEL_INTERNAL_TYPE_DECL(AnyType); +CEL_INTERNAL_SIMPLE_TYPE_STANDALONES(AnyType); } // namespace cel diff --git a/base/types/bool_type.cc b/base/types/bool_type.cc index 6b81034e7..16a4174f8 100644 --- a/base/types/bool_type.cc +++ b/base/types/bool_type.cc @@ -14,15 +14,28 @@ #include "base/types/bool_type.h" -#include "internal/no_destructor.h" +#include "absl/base/attributes.h" +#include "absl/base/call_once.h" namespace cel { CEL_INTERNAL_TYPE_IMPL(BoolType); -const BoolType& BoolType::Get() { - static const internal::NoDestructor instance; - return *instance; +namespace { + +ABSL_CONST_INIT absl::once_flag instance_once; +alignas(Persistent) char instance_storage[sizeof( + Persistent)]; + +} // namespace + +const Persistent& BoolType::Get() { + absl::call_once(instance_once, []() { + base_internal::PersistentHandleFactory::MakeAt( + &instance_storage[0]); + }); + return *reinterpret_cast*>( + &instance_storage[0]); } } // namespace cel diff --git a/base/types/bool_type.h b/base/types/bool_type.h index cee2a43f6..b50ce107d 100644 --- a/base/types/bool_type.h +++ b/base/types/bool_type.h @@ -15,42 +15,39 @@ #ifndef THIRD_PARTY_CEL_CPP_BASE_TYPES_BOOL_TYPE_H_ #define THIRD_PARTY_CEL_CPP_BASE_TYPES_BOOL_TYPE_H_ -#include -#include -#include - -#include "absl/strings/string_view.h" #include "base/kind.h" #include "base/type.h" namespace cel { -class BoolType final : public Type { +class BoolValue; + +class BoolType final : public base_internal::SimpleType { + private: + using Base = base_internal::SimpleType; + public: - Kind kind() const override { return Kind::kBool; } + using Base::kKind; - absl::string_view name() const override { return "bool"; } + using Base::kName; - private: - friend class BoolValue; - friend class TypeFactory; - template - friend class internal::NoDestructor; - friend class base_internal::TypeHandleBase; + using Base::Is; + + using Base::kind; - // Called by base_internal::TypeHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Type& type) { return type.kind() == Kind::kBool; } + using Base::name; - ABSL_ATTRIBUTE_PURE_FUNCTION static const BoolType& Get(); + using Base::DebugString; - BoolType() = default; + using Base::HashValue; - BoolType(const BoolType&) = delete; - BoolType(BoolType&&) = delete; + using Base::Equals; + + private: + CEL_INTERNAL_SIMPLE_TYPE_MEMBERS(BoolType, BoolValue); }; -CEL_INTERNAL_TYPE_DECL(BoolType); +CEL_INTERNAL_SIMPLE_TYPE_STANDALONES(BoolType); } // namespace cel diff --git a/base/types/bytes_type.cc b/base/types/bytes_type.cc index f174af975..bff9aa5fa 100644 --- a/base/types/bytes_type.cc +++ b/base/types/bytes_type.cc @@ -14,15 +14,28 @@ #include "base/types/bytes_type.h" -#include "internal/no_destructor.h" +#include "absl/base/attributes.h" +#include "absl/base/call_once.h" namespace cel { CEL_INTERNAL_TYPE_IMPL(BytesType); -const BytesType& BytesType::Get() { - static const internal::NoDestructor instance; - return *instance; +namespace { + +ABSL_CONST_INIT absl::once_flag instance_once; +alignas(Persistent) char instance_storage[sizeof( + Persistent)]; + +} // namespace + +const Persistent& BytesType::Get() { + absl::call_once(instance_once, []() { + base_internal::PersistentHandleFactory::MakeAt( + &instance_storage[0]); + }); + return *reinterpret_cast*>( + &instance_storage[0]); } } // namespace cel diff --git a/base/types/bytes_type.h b/base/types/bytes_type.h index 412f684d7..dacd8cd3c 100644 --- a/base/types/bytes_type.h +++ b/base/types/bytes_type.h @@ -15,42 +15,39 @@ #ifndef THIRD_PARTY_CEL_CPP_BASE_TYPES_BYTES_TYPE_H_ #define THIRD_PARTY_CEL_CPP_BASE_TYPES_BYTES_TYPE_H_ -#include -#include -#include - -#include "absl/strings/string_view.h" #include "base/kind.h" #include "base/type.h" namespace cel { -class BytesType final : public Type { +class BytesValue; + +class BytesType final : public base_internal::SimpleType { + private: + using Base = base_internal::SimpleType; + public: - Kind kind() const override { return Kind::kBytes; } + using Base::kKind; - absl::string_view name() const override { return "bytes"; } + using Base::kName; - private: - friend class BytesValue; - friend class TypeFactory; - template - friend class internal::NoDestructor; - friend class base_internal::TypeHandleBase; + using Base::Is; + + using Base::kind; - // Called by base_internal::TypeHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Type& type) { return type.kind() == Kind::kBytes; } + using Base::name; - ABSL_ATTRIBUTE_PURE_FUNCTION static const BytesType& Get(); + using Base::DebugString; - BytesType() = default; + using Base::HashValue; - BytesType(const BytesType&) = delete; - BytesType(BytesType&&) = delete; + using Base::Equals; + + private: + CEL_INTERNAL_SIMPLE_TYPE_MEMBERS(BytesType, BytesValue); }; -CEL_INTERNAL_TYPE_DECL(BytesType); +CEL_INTERNAL_SIMPLE_TYPE_STANDALONES(BytesType); } // namespace cel diff --git a/base/types/double_type.cc b/base/types/double_type.cc index 8735f51b4..39f58ae08 100644 --- a/base/types/double_type.cc +++ b/base/types/double_type.cc @@ -14,15 +14,28 @@ #include "base/types/double_type.h" -#include "internal/no_destructor.h" +#include "absl/base/attributes.h" +#include "absl/base/call_once.h" namespace cel { CEL_INTERNAL_TYPE_IMPL(DoubleType); -const DoubleType& DoubleType::Get() { - static const internal::NoDestructor instance; - return *instance; +namespace { + +ABSL_CONST_INIT absl::once_flag instance_once; +alignas(Persistent) char instance_storage[sizeof( + Persistent)]; + +} // namespace + +const Persistent& DoubleType::Get() { + absl::call_once(instance_once, []() { + base_internal::PersistentHandleFactory::MakeAt< + DoubleType>(&instance_storage[0]); + }); + return *reinterpret_cast*>( + &instance_storage[0]); } } // namespace cel diff --git a/base/types/double_type.h b/base/types/double_type.h index 946cbe080..12589bd76 100644 --- a/base/types/double_type.h +++ b/base/types/double_type.h @@ -15,42 +15,39 @@ #ifndef THIRD_PARTY_CEL_CPP_BASE_TYPES_DOUBLE_TYPE_H_ #define THIRD_PARTY_CEL_CPP_BASE_TYPES_DOUBLE_TYPE_H_ -#include -#include -#include - -#include "absl/strings/string_view.h" #include "base/kind.h" #include "base/type.h" namespace cel { -class DoubleType final : public Type { +class DoubleValue; + +class DoubleType final : public base_internal::SimpleType { + private: + using Base = base_internal::SimpleType; + public: - Kind kind() const override { return Kind::kDouble; } + using Base::kKind; - absl::string_view name() const override { return "double"; } + using Base::kName; - private: - friend class DoubleValue; - friend class TypeFactory; - template - friend class internal::NoDestructor; - friend class base_internal::TypeHandleBase; + using Base::Is; + + using Base::kind; - // Called by base_internal::TypeHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Type& type) { return type.kind() == Kind::kDouble; } + using Base::name; - ABSL_ATTRIBUTE_PURE_FUNCTION static const DoubleType& Get(); + using Base::DebugString; - DoubleType() = default; + using Base::HashValue; - DoubleType(const DoubleType&) = delete; - DoubleType(DoubleType&&) = delete; + using Base::Equals; + + private: + CEL_INTERNAL_SIMPLE_TYPE_MEMBERS(DoubleType, DoubleValue); }; -CEL_INTERNAL_TYPE_DECL(DoubleType); +CEL_INTERNAL_SIMPLE_TYPE_STANDALONES(DoubleType); } // namespace cel diff --git a/base/types/duration_type.cc b/base/types/duration_type.cc index f7617b722..e1486f360 100644 --- a/base/types/duration_type.cc +++ b/base/types/duration_type.cc @@ -14,15 +14,28 @@ #include "base/types/duration_type.h" -#include "internal/no_destructor.h" +#include "absl/base/attributes.h" +#include "absl/base/call_once.h" namespace cel { CEL_INTERNAL_TYPE_IMPL(DurationType); -const DurationType& DurationType::Get() { - static const internal::NoDestructor instance; - return *instance; +namespace { + +ABSL_CONST_INIT absl::once_flag instance_once; +alignas(Persistent) char instance_storage[sizeof( + Persistent)]; + +} // namespace + +const Persistent& DurationType::Get() { + absl::call_once(instance_once, []() { + base_internal::PersistentHandleFactory::MakeAt< + DurationType>(&instance_storage[0]); + }); + return *reinterpret_cast*>( + &instance_storage[0]); } } // namespace cel diff --git a/base/types/duration_type.h b/base/types/duration_type.h index b6855751a..79a91c12d 100644 --- a/base/types/duration_type.h +++ b/base/types/duration_type.h @@ -15,42 +15,39 @@ #ifndef THIRD_PARTY_CEL_CPP_BASE_TYPES_DURATION_TYPE_H_ #define THIRD_PARTY_CEL_CPP_BASE_TYPES_DURATION_TYPE_H_ -#include -#include -#include - -#include "absl/strings/string_view.h" #include "base/kind.h" #include "base/type.h" namespace cel { -class DurationType final : public Type { +class DurationValue; + +class DurationType final : public base_internal::SimpleType { + private: + using Base = base_internal::SimpleType; + public: - Kind kind() const override { return Kind::kDuration; } + using Base::kKind; - absl::string_view name() const override { return "google.protobuf.Duration"; } + using Base::kName; - private: - friend class DurationValue; - friend class TypeFactory; - template - friend class internal::NoDestructor; - friend class base_internal::TypeHandleBase; + using Base::Is; + + using Base::kind; - // Called by base_internal::TypeHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Type& type) { return type.kind() == Kind::kDuration; } + using Base::name; - ABSL_ATTRIBUTE_PURE_FUNCTION static const DurationType& Get(); + using Base::DebugString; - DurationType() = default; + using Base::HashValue; - DurationType(const DurationType&) = delete; - DurationType(DurationType&&) = delete; + using Base::Equals; + + private: + CEL_INTERNAL_SIMPLE_TYPE_MEMBERS(DurationType, DurationValue); }; -CEL_INTERNAL_TYPE_DECL(DurationType); +CEL_INTERNAL_SIMPLE_TYPE_STANDALONES(DurationType); } // namespace cel diff --git a/base/types/dyn_type.cc b/base/types/dyn_type.cc index c9a12e3ae..dbca71e4a 100644 --- a/base/types/dyn_type.cc +++ b/base/types/dyn_type.cc @@ -14,15 +14,28 @@ #include "base/types/dyn_type.h" -#include "internal/no_destructor.h" +#include "absl/base/attributes.h" +#include "absl/base/call_once.h" namespace cel { CEL_INTERNAL_TYPE_IMPL(DynType); -const DynType& DynType::Get() { - static const internal::NoDestructor instance; - return *instance; +namespace { + +ABSL_CONST_INIT absl::once_flag instance_once; +alignas(Persistent) char instance_storage[sizeof( + Persistent)]; + +} // namespace + +const Persistent& DynType::Get() { + absl::call_once(instance_once, []() { + base_internal::PersistentHandleFactory::MakeAt( + &instance_storage[0]); + }); + return *reinterpret_cast*>( + &instance_storage[0]); } } // namespace cel diff --git a/base/types/dyn_type.h b/base/types/dyn_type.h index 7e5439ddd..448caba2f 100644 --- a/base/types/dyn_type.h +++ b/base/types/dyn_type.h @@ -15,41 +15,39 @@ #ifndef THIRD_PARTY_CEL_CPP_BASE_TYPES_DYN_TYPE_H_ #define THIRD_PARTY_CEL_CPP_BASE_TYPES_DYN_TYPE_H_ -#include -#include -#include - -#include "absl/strings/string_view.h" #include "base/kind.h" #include "base/type.h" namespace cel { -class DynType final : public Type { +class DynValue; + +class DynType final : public base_internal::SimpleType { + private: + using Base = base_internal::SimpleType; + public: - Kind kind() const override { return Kind::kDyn; } + using Base::kKind; - absl::string_view name() const override { return "dyn"; } + using Base::kName; - private: - friend class TypeFactory; - template - friend class internal::NoDestructor; - friend class base_internal::TypeHandleBase; + using Base::Is; + + using Base::kind; - // Called by base_internal::TypeHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Type& type) { return type.kind() == Kind::kDyn; } + using Base::name; - ABSL_ATTRIBUTE_PURE_FUNCTION static const DynType& Get(); + using Base::DebugString; - DynType() = default; + using Base::HashValue; - DynType(const DynType&) = delete; - DynType(DynType&&) = delete; + using Base::Equals; + + private: + CEL_INTERNAL_SIMPLE_TYPE_MEMBERS(DynType, DynValue); }; -CEL_INTERNAL_TYPE_DECL(DynType); +CEL_INTERNAL_SIMPLE_TYPE_STANDALONES(DynType); } // namespace cel diff --git a/base/types/enum_type.cc b/base/types/enum_type.cc index ad6dc0371..6b3692299 100644 --- a/base/types/enum_type.cc +++ b/base/types/enum_type.cc @@ -14,6 +14,10 @@ #include "base/types/enum_type.h" +#include + +#include "absl/base/macros.h" +#include "absl/hash/hash.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/variant.h" @@ -22,6 +26,13 @@ namespace cel { CEL_INTERNAL_TYPE_IMPL(EnumType); +EnumType::EnumType() : base_internal::HeapData(kKind) { + // Ensure `Type*` and `base_internal::HeapData*` are not thunked. + ABSL_ASSERT( + reinterpret_cast(static_cast(this)) == + reinterpret_cast(static_cast(this))); +} + struct EnumType::FindConstantVisitor final { const EnumType& enum_type; @@ -38,4 +49,14 @@ absl::StatusOr EnumType::FindConstant(ConstantId id) const { return absl::visit(FindConstantVisitor{*this}, id.data_); } +void EnumType::HashValue(absl::HashState state) const { + absl::HashState::combine(std::move(state), kind(), name(), TypeId()); +} + +bool EnumType::Equals(const Type& other) const { + return kind() == other.kind() && + name() == static_cast(other).name() && + TypeId() == static_cast(other).TypeId(); +} + } // namespace cel diff --git a/base/types/enum_type.h b/base/types/enum_type.h index 500318751..31ca0763d 100644 --- a/base/types/enum_type.h +++ b/base/types/enum_type.h @@ -15,6 +15,7 @@ #ifndef THIRD_PARTY_CEL_CPP_BASE_TYPES_ENUM_TYPE_H_ #define THIRD_PARTY_CEL_CPP_BASE_TYPES_ENUM_TYPE_H_ +#include #include #include #include @@ -23,18 +24,21 @@ #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/variant.h" +#include "base/internal/data.h" #include "base/kind.h" #include "base/type.h" #include "internal/rtti.h" namespace cel { +class MemoryManager; +class EnumValue; class TypedEnumValueFactory; class TypeManager; // EnumType represents an enumeration type. An enumeration is a set of constants // that can be looked up by name and/or number. -class EnumType : public Type { +class EnumType : public Type, public base_internal::HeapData { public: struct Constant; @@ -58,13 +62,25 @@ class EnumType : public Type { absl::variant data_; }; - Kind kind() const final { return Kind::kEnum; } + static constexpr Kind kKind = Kind::kEnum; + + static bool Is(const Type& type) { return type.kind() == kKind; } + + Kind kind() const { return kKind; } + + virtual absl::string_view name() const = 0; + + std::string DebugString() const { return std::string(name()); } + + virtual void HashValue(absl::HashState state) const; + + virtual bool Equals(const Type& other) const; // Find the constant definition for the given identifier. absl::StatusOr FindConstant(ConstantId id) const; protected: - EnumType() = default; + EnumType(); // Construct a new instance of EnumValue with a type of this. Called by // EnumValue::New. @@ -92,19 +108,14 @@ class EnumType : public Type { friend struct NewInstanceVisitor; friend struct FindConstantVisitor; + friend class MemoryManager; friend class EnumValue; friend class TypeFactory; - friend class base_internal::TypeHandleBase; - - // Called by base_internal::TypeHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Type& type) { return type.kind() == Kind::kEnum; } + friend class base_internal::PersistentTypeHandle; EnumType(const EnumType&) = delete; EnumType(EnumType&&) = delete; - std::pair SizeAndAlignment() const override = 0; - // Called by CEL_IMPLEMENT_ENUM_TYPE() and Is() to perform type checking. virtual internal::TypeInfo TypeId() const = 0; }; diff --git a/base/types/error_type.cc b/base/types/error_type.cc index f45c69d31..eefd2d24d 100644 --- a/base/types/error_type.cc +++ b/base/types/error_type.cc @@ -14,15 +14,28 @@ #include "base/types/error_type.h" -#include "internal/no_destructor.h" +#include "absl/base/attributes.h" +#include "absl/base/call_once.h" namespace cel { CEL_INTERNAL_TYPE_IMPL(ErrorType); -const ErrorType& ErrorType::Get() { - static const internal::NoDestructor instance; - return *instance; +namespace { + +ABSL_CONST_INIT absl::once_flag instance_once; +alignas(Persistent) char instance_storage[sizeof( + Persistent)]; + +} // namespace + +const Persistent& ErrorType::Get() { + absl::call_once(instance_once, []() { + base_internal::PersistentHandleFactory::MakeAt( + &instance_storage[0]); + }); + return *reinterpret_cast*>( + &instance_storage[0]); } } // namespace cel diff --git a/base/types/error_type.h b/base/types/error_type.h index 96059aa0e..c44db8857 100644 --- a/base/types/error_type.h +++ b/base/types/error_type.h @@ -15,42 +15,39 @@ #ifndef THIRD_PARTY_CEL_CPP_BASE_TYPES_ERROR_TYPE_H_ #define THIRD_PARTY_CEL_CPP_BASE_TYPES_ERROR_TYPE_H_ -#include -#include -#include - -#include "absl/strings/string_view.h" #include "base/kind.h" #include "base/type.h" namespace cel { -class ErrorType final : public Type { +class ErrorValue; + +class ErrorType final : public base_internal::SimpleType { + private: + using Base = base_internal::SimpleType; + public: - Kind kind() const override { return Kind::kError; } + using Base::kKind; - absl::string_view name() const override { return "*error*"; } + using Base::kName; - private: - friend class ErrorValue; - friend class TypeFactory; - template - friend class internal::NoDestructor; - friend class base_internal::TypeHandleBase; + using Base::Is; + + using Base::kind; - // Called by base_internal::TypeHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Type& type) { return type.kind() == Kind::kError; } + using Base::name; - ABSL_ATTRIBUTE_PURE_FUNCTION static const ErrorType& Get(); + using Base::DebugString; - ErrorType() = default; + using Base::HashValue; - ErrorType(const ErrorType&) = delete; - ErrorType(ErrorType&&) = delete; + using Base::Equals; + + private: + CEL_INTERNAL_SIMPLE_TYPE_MEMBERS(ErrorType, ErrorValue); }; -CEL_INTERNAL_TYPE_DECL(ErrorType); +CEL_INTERNAL_SIMPLE_TYPE_STANDALONES(ErrorType); } // namespace cel diff --git a/base/types/int_type.cc b/base/types/int_type.cc index e722f8f31..06ca432ff 100644 --- a/base/types/int_type.cc +++ b/base/types/int_type.cc @@ -14,15 +14,28 @@ #include "base/types/int_type.h" -#include "internal/no_destructor.h" +#include "absl/base/attributes.h" +#include "absl/base/call_once.h" namespace cel { CEL_INTERNAL_TYPE_IMPL(IntType); -const IntType& IntType::Get() { - static const internal::NoDestructor instance; - return *instance; +namespace { + +ABSL_CONST_INIT absl::once_flag instance_once; +alignas(Persistent) char instance_storage[sizeof( + Persistent)]; + +} // namespace + +const Persistent& IntType::Get() { + absl::call_once(instance_once, []() { + base_internal::PersistentHandleFactory::MakeAt( + &instance_storage[0]); + }); + return *reinterpret_cast*>( + &instance_storage[0]); } } // namespace cel diff --git a/base/types/int_type.h b/base/types/int_type.h index 579f41859..796b9d059 100644 --- a/base/types/int_type.h +++ b/base/types/int_type.h @@ -15,42 +15,39 @@ #ifndef THIRD_PARTY_CEL_CPP_BASE_TYPES_INT_TYPE_H_ #define THIRD_PARTY_CEL_CPP_BASE_TYPES_INT_TYPE_H_ -#include -#include -#include - -#include "absl/strings/string_view.h" #include "base/kind.h" #include "base/type.h" namespace cel { -class IntType final : public Type { +class IntValue; + +class IntType final : public base_internal::SimpleType { + private: + using Base = base_internal::SimpleType; + public: - Kind kind() const override { return Kind::kInt; } + using Base::kKind; - absl::string_view name() const override { return "int"; } + using Base::kName; - private: - friend class IntValue; - friend class TypeFactory; - template - friend class internal::NoDestructor; - friend class base_internal::TypeHandleBase; + using Base::Is; + + using Base::kind; - // Called by base_internal::TypeHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Type& type) { return type.kind() == Kind::kInt; } + using Base::name; - ABSL_ATTRIBUTE_PURE_FUNCTION static const IntType& Get(); + using Base::DebugString; - IntType() = default; + using Base::HashValue; - IntType(const IntType&) = delete; - IntType(IntType&&) = delete; + using Base::Equals; + + private: + CEL_INTERNAL_SIMPLE_TYPE_MEMBERS(IntType, IntValue); }; -CEL_INTERNAL_TYPE_DECL(IntType); +CEL_INTERNAL_SIMPLE_TYPE_STANDALONES(IntType); } // namespace cel diff --git a/base/types/list_type.cc b/base/types/list_type.cc index 4f502fbf4..7c5181078 100644 --- a/base/types/list_type.cc +++ b/base/types/list_type.cc @@ -14,12 +14,24 @@ #include "base/types/list_type.h" +#include +#include + +#include "absl/base/macros.h" #include "absl/strings/str_cat.h" namespace cel { CEL_INTERNAL_TYPE_IMPL(ListType); +ListType::ListType(Persistent element) + : base_internal::HeapData(kKind), element_(std::move(element)) { + // Ensure `Type*` and `base_internal::HeapData*` are not thunked. + ABSL_ASSERT( + reinterpret_cast(static_cast(this)) == + reinterpret_cast(static_cast(this))); +} + std::string ListType::DebugString() const { return absl::StrCat(name(), "(", element()->DebugString(), ")"); } @@ -28,13 +40,13 @@ bool ListType::Equals(const Type& other) const { if (kind() != other.kind()) { return false; } - return element() == internal::down_cast(other).element(); + return element() == static_cast(other).element(); } void ListType::HashValue(absl::HashState state) const { // We specifically hash the element first and then call the parent method to // avoid hash suffix/prefix collisions. - Type::HashValue(absl::HashState::combine(std::move(state), element())); + absl::HashState::combine(std::move(state), element(), kind(), name()); } } // namespace cel diff --git a/base/types/list_type.h b/base/types/list_type.h index d339c6229..f55bcc618 100644 --- a/base/types/list_type.h +++ b/base/types/list_type.h @@ -15,55 +15,54 @@ #ifndef THIRD_PARTY_CEL_CPP_BASE_TYPES_LIST_TYPE_H_ #define THIRD_PARTY_CEL_CPP_BASE_TYPES_LIST_TYPE_H_ +#include #include #include #include #include "absl/hash/hash.h" #include "absl/strings/string_view.h" +#include "base/internal/data.h" #include "base/kind.h" #include "base/type.h" namespace cel { +class MemoryManager; + // ListType represents a list type. A list is a sequential container where each // element is the same type. -class ListType : public Type { +class ListType final : public Type, public base_internal::HeapData { // I would have liked to make this class final, but we cannot instantiate // Persistent or Transient at this point. It must be // done after the post include below. Maybe we should separate out the post // includes on a per type basis so we can do that? public: - Kind kind() const final { return Kind::kList; } + static constexpr Kind kKind = Kind::kList; - absl::string_view name() const final { return "list"; } + static bool Is(const Type& type) { return type.kind() == kKind; } - std::string DebugString() const final; + Kind kind() const { return kKind; } - // Returns the type of the elements in the list. - virtual Persistent element() const = 0; + absl::string_view name() const { return KindToString(kind()); } - private: - friend class TypeFactory; - friend class base_internal::TypeHandleBase; - friend class base_internal::ListTypeImpl; + std::string DebugString() const; - // Called by base_internal::TypeHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Type& type) { return type.kind() == Kind::kList; } + void HashValue(absl::HashState state) const; - ListType() = default; + bool Equals(const Type& other) const; - ListType(const ListType&) = delete; - ListType(ListType&&) = delete; + // Returns the type of the elements in the list. + const Persistent& element() const { return element_; } - std::pair SizeAndAlignment() const override = 0; + private: + friend class MemoryManager; + friend class TypeFactory; + friend class base_internal::PersistentTypeHandle; - // Called by base_internal::TypeHandleBase. - bool Equals(const Type& other) const final; + explicit ListType(Persistent element); - // Called by base_internal::TypeHandleBase. - void HashValue(absl::HashState state) const final; + const Persistent element_; }; CEL_INTERNAL_TYPE_DECL(ListType); diff --git a/base/types/map_type.cc b/base/types/map_type.cc index a55f275ca..3bb249591 100644 --- a/base/types/map_type.cc +++ b/base/types/map_type.cc @@ -14,12 +14,26 @@ #include "base/types/map_type.h" +#include +#include + +#include "absl/base/macros.h" #include "absl/strings/str_cat.h" namespace cel { CEL_INTERNAL_TYPE_IMPL(MapType); +MapType::MapType(Persistent key, Persistent value) + : base_internal::HeapData(kKind), + key_(std::move(key)), + value_(std::move(value)) { + // Ensure `Type*` and `base_internal::HeapData*` are not thunked. + ABSL_ASSERT( + reinterpret_cast(static_cast(this)) == + reinterpret_cast(static_cast(this))); +} + std::string MapType::DebugString() const { return absl::StrCat(name(), "(", key()->DebugString(), ", ", value()->DebugString(), ")"); @@ -29,14 +43,14 @@ bool MapType::Equals(const Type& other) const { if (kind() != other.kind()) { return false; } - return key() == internal::down_cast(other).key() && - value() == internal::down_cast(other).value(); + return key() == static_cast(other).key() && + value() == static_cast(other).value(); } void MapType::HashValue(absl::HashState state) const { // We specifically hash the element first and then call the parent method to // avoid hash suffix/prefix collisions. - Type::HashValue(absl::HashState::combine(std::move(state), key(), value())); + absl::HashState::combine(std::move(state), key(), value(), kind(), name()); } } // namespace cel diff --git a/base/types/map_type.h b/base/types/map_type.h index 9da1cb0f6..41be1429a 100644 --- a/base/types/map_type.h +++ b/base/types/map_type.h @@ -15,60 +15,59 @@ #ifndef THIRD_PARTY_CEL_CPP_BASE_TYPES_MAP_TYPE_H_ #define THIRD_PARTY_CEL_CPP_BASE_TYPES_MAP_TYPE_H_ +#include #include #include #include #include "absl/hash/hash.h" #include "absl/strings/string_view.h" +#include "base/internal/data.h" #include "base/kind.h" #include "base/type.h" namespace cel { +class MemoryManager; class TypeFactory; // MapType represents a map type. A map is container of key and value pairs // where each key appears at most once. -class MapType : public Type { +class MapType final : public Type, public base_internal::HeapData { // I would have liked to make this class final, but we cannot instantiate // Persistent or Transient at this point. It must be // done after the post include below. Maybe we should separate out the post // includes on a per type basis so we can do that? public: - Kind kind() const final { return Kind::kMap; } + static constexpr Kind kKind = Kind::kMap; - absl::string_view name() const final { return "map"; } + static bool Is(const Type& type) { return type.kind() == kKind; } - std::string DebugString() const final; + Kind kind() const { return kKind; } - // Returns the type of the keys in the map. - virtual Persistent key() const = 0; + absl::string_view name() const { return KindToString(kind()); } - // Returns the type of the values in the map. - virtual Persistent value() const = 0; + std::string DebugString() const; - private: - friend class TypeFactory; - friend class base_internal::TypeHandleBase; - friend class base_internal::MapTypeImpl; + void HashValue(absl::HashState state) const; - // Called by base_internal::TypeHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Type& type) { return type.kind() == Kind::kMap; } + bool Equals(const Type& other) const; - MapType() = default; + // Returns the type of the keys in the map. + const Persistent& key() const { return key_; } - MapType(const MapType&) = delete; - MapType(MapType&&) = delete; + // Returns the type of the values in the map. + const Persistent& value() const { return value_; } - std::pair SizeAndAlignment() const override = 0; + private: + friend class MemoryManager; + friend class TypeFactory; + friend class base_internal::PersistentTypeHandle; - // Called by base_internal::TypeHandleBase. - bool Equals(const Type& other) const final; + explicit MapType(Persistent key, Persistent value); - // Called by base_internal::TypeHandleBase. - void HashValue(absl::HashState state) const final; + const Persistent key_; + const Persistent value_; }; CEL_INTERNAL_TYPE_DECL(MapType); diff --git a/base/types/null_type.cc b/base/types/null_type.cc index b964cd2f1..8e97a6624 100644 --- a/base/types/null_type.cc +++ b/base/types/null_type.cc @@ -14,15 +14,28 @@ #include "base/types/null_type.h" -#include "internal/no_destructor.h" +#include "absl/base/attributes.h" +#include "absl/base/call_once.h" namespace cel { CEL_INTERNAL_TYPE_IMPL(NullType); -const NullType& NullType::Get() { - static const internal::NoDestructor instance; - return *instance; +namespace { + +ABSL_CONST_INIT absl::once_flag instance_once; +alignas(Persistent) char instance_storage[sizeof( + Persistent)]; + +} // namespace + +const Persistent& NullType::Get() { + absl::call_once(instance_once, []() { + base_internal::PersistentHandleFactory::MakeAt( + &instance_storage[0]); + }); + return *reinterpret_cast*>( + &instance_storage[0]); } } // namespace cel diff --git a/base/types/null_type.h b/base/types/null_type.h index 4544f73d5..8d1d96f55 100644 --- a/base/types/null_type.h +++ b/base/types/null_type.h @@ -15,43 +15,39 @@ #ifndef THIRD_PARTY_CEL_CPP_BASE_TYPES_NULL_TYPE_H_ #define THIRD_PARTY_CEL_CPP_BASE_TYPES_NULL_TYPE_H_ -#include -#include -#include - -#include "absl/strings/string_view.h" #include "base/kind.h" #include "base/type.h" namespace cel { -class NullType final : public Type { +class NullValue; + +class NullType final : public base_internal::SimpleType { + private: + using Base = base_internal::SimpleType; + public: - Kind kind() const override { return Kind::kNullType; } + using Base::kKind; - absl::string_view name() const override { return "null_type"; } + using Base::kName; - // Note GCC does not consider a friend member as a member of a friend. - ABSL_ATTRIBUTE_PURE_FUNCTION static const NullType& Get(); + using Base::Is; - private: - friend class NullValue; - friend class TypeFactory; - template - friend class internal::NoDestructor; - friend class base_internal::TypeHandleBase; + using Base::kind; + + using Base::name; - // Called by base_internal::TypeHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Type& type) { return type.kind() == Kind::kNullType; } + using Base::DebugString; - NullType() = default; + using Base::HashValue; - NullType(const NullType&) = delete; - NullType(NullType&&) = delete; + using Base::Equals; + + private: + CEL_INTERNAL_SIMPLE_TYPE_MEMBERS(NullType, NullValue); }; -CEL_INTERNAL_TYPE_DECL(NullType); +CEL_INTERNAL_SIMPLE_TYPE_STANDALONES(NullType); } // namespace cel diff --git a/base/types/string_type.cc b/base/types/string_type.cc index 57be9ac16..51c42cf4c 100644 --- a/base/types/string_type.cc +++ b/base/types/string_type.cc @@ -14,15 +14,28 @@ #include "base/types/string_type.h" -#include "internal/no_destructor.h" +#include "absl/base/attributes.h" +#include "absl/base/call_once.h" namespace cel { CEL_INTERNAL_TYPE_IMPL(StringType); -const StringType& StringType::Get() { - static const internal::NoDestructor instance; - return *instance; +namespace { + +ABSL_CONST_INIT absl::once_flag instance_once; +alignas(Persistent) char instance_storage[sizeof( + Persistent)]; + +} // namespace + +const Persistent& StringType::Get() { + absl::call_once(instance_once, []() { + base_internal::PersistentHandleFactory::MakeAt< + StringType>(&instance_storage[0]); + }); + return *reinterpret_cast*>( + &instance_storage[0]); } } // namespace cel diff --git a/base/types/string_type.h b/base/types/string_type.h index eb75cdbf2..ed2e8885a 100644 --- a/base/types/string_type.h +++ b/base/types/string_type.h @@ -15,42 +15,39 @@ #ifndef THIRD_PARTY_CEL_CPP_BASE_TYPES_STRING_TYPE_H_ #define THIRD_PARTY_CEL_CPP_BASE_TYPES_STRING_TYPE_H_ -#include -#include -#include - -#include "absl/strings/string_view.h" #include "base/kind.h" #include "base/type.h" namespace cel { -class StringType final : public Type { +class StringValue; + +class StringType final : public base_internal::SimpleType { + private: + using Base = base_internal::SimpleType; + public: - Kind kind() const override { return Kind::kString; } + using Base::kKind; - absl::string_view name() const override { return "string"; } + using Base::kName; - private: - friend class StringValue; - friend class TypeFactory; - template - friend class internal::NoDestructor; - friend class base_internal::TypeHandleBase; + using Base::Is; + + using Base::kind; - // Called by base_internal::TypeHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Type& type) { return type.kind() == Kind::kString; } + using Base::name; - ABSL_ATTRIBUTE_PURE_FUNCTION static const StringType& Get(); + using Base::DebugString; - StringType() = default; + using Base::HashValue; - StringType(const StringType&) = delete; - StringType(StringType&&) = delete; + using Base::Equals; + + private: + CEL_INTERNAL_SIMPLE_TYPE_MEMBERS(StringType, StringValue); }; -CEL_INTERNAL_TYPE_DECL(StringType); +CEL_INTERNAL_SIMPLE_TYPE_STANDALONES(StringType); } // namespace cel diff --git a/base/types/struct_type.cc b/base/types/struct_type.cc index 0b5c20ece..633af5d2f 100644 --- a/base/types/struct_type.cc +++ b/base/types/struct_type.cc @@ -14,6 +14,10 @@ #include "base/types/struct_type.h" +#include + +#include "absl/base/macros.h" +#include "absl/hash/hash.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/variant.h" @@ -22,6 +26,13 @@ namespace cel { CEL_INTERNAL_TYPE_IMPL(StructType); +StructType::StructType() : base_internal::HeapData(kKind) { + // Ensure `Type*` and `base_internal::HeapData*` are not thunked. + ABSL_ASSERT( + reinterpret_cast(static_cast(this)) == + reinterpret_cast(static_cast(this))); +} + struct StructType::FindFieldVisitor final { const StructType& struct_type; TypeManager& type_manager; @@ -40,4 +51,14 @@ absl::StatusOr StructType::FindField( return absl::visit(FindFieldVisitor{*this, type_manager}, id.data_); } +void StructType::HashValue(absl::HashState state) const { + absl::HashState::combine(std::move(state), kind(), name(), TypeId()); +} + +bool StructType::Equals(const Type& other) const { + return kind() == other.kind() && + name() == static_cast(other).name() && + TypeId() == static_cast(other).TypeId(); +} + } // namespace cel diff --git a/base/types/struct_type.h b/base/types/struct_type.h index ef9acfb7f..5cfdfe34f 100644 --- a/base/types/struct_type.h +++ b/base/types/struct_type.h @@ -15,6 +15,7 @@ #ifndef THIRD_PARTY_CEL_CPP_BASE_TYPES_STRUCT_TYPE_H_ #define THIRD_PARTY_CEL_CPP_BASE_TYPES_STRUCT_TYPE_H_ +#include #include #include #include @@ -29,12 +30,14 @@ namespace cel { +class MemoryManager; +class StructValue; class TypedStructValueFactory; class TypeManager; // StructType represents an struct type. An struct is a set of fields // that can be looked up by name and/or number. -class StructType : public Type { +class StructType : public Type, public base_internal::HeapData { public: struct Field; @@ -58,13 +61,25 @@ class StructType : public Type { absl::variant data_; }; - Kind kind() const final { return Kind::kStruct; } + static constexpr Kind kKind = Kind::kStruct; + + static bool Is(const Type& type) { return type.kind() == kKind; } + + Kind kind() const { return kKind; } + + virtual absl::string_view name() const = 0; + + std::string DebugString() const { return std::string(name()); } + + virtual void HashValue(absl::HashState state) const; + + virtual bool Equals(const Type& other) const; // Find the field definition for the given identifier. absl::StatusOr FindField(TypeManager& type_manager, FieldId id) const; protected: - StructType() = default; + StructType(); virtual absl::StatusOr> NewInstance( TypedStructValueFactory& factory) const = 0; @@ -83,19 +98,14 @@ class StructType : public Type { struct FindFieldVisitor; friend struct FindFieldVisitor; + friend class MemoryManager; friend class TypeFactory; - friend class base_internal::TypeHandleBase; + friend class base_internal::PersistentTypeHandle; friend class StructValue; - // Called by base_internal::TypeHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Type& type) { return type.kind() == Kind::kStruct; } - StructType(const StructType&) = delete; StructType(StructType&&) = delete; - std::pair SizeAndAlignment() const override = 0; - // Called by CEL_IMPLEMENT_STRUCT_TYPE() and Is() to perform type checking. virtual internal::TypeInfo TypeId() const = 0; }; diff --git a/base/types/timestamp_type.cc b/base/types/timestamp_type.cc index 7b1239689..bacca83ee 100644 --- a/base/types/timestamp_type.cc +++ b/base/types/timestamp_type.cc @@ -14,15 +14,28 @@ #include "base/types/timestamp_type.h" -#include "internal/no_destructor.h" +#include "absl/base/attributes.h" +#include "absl/base/call_once.h" namespace cel { CEL_INTERNAL_TYPE_IMPL(TimestampType); -const TimestampType& TimestampType::Get() { - static const internal::NoDestructor instance; - return *instance; +namespace { + +ABSL_CONST_INIT absl::once_flag instance_once; +alignas(Persistent) char instance_storage[sizeof( + Persistent)]; + +} // namespace + +const Persistent& TimestampType::Get() { + absl::call_once(instance_once, []() { + base_internal::PersistentHandleFactory::MakeAt< + TimestampType>(&instance_storage[0]); + }); + return *reinterpret_cast*>( + &instance_storage[0]); } } // namespace cel diff --git a/base/types/timestamp_type.h b/base/types/timestamp_type.h index 20c7209b5..b0150e6bf 100644 --- a/base/types/timestamp_type.h +++ b/base/types/timestamp_type.h @@ -15,44 +15,39 @@ #ifndef THIRD_PARTY_CEL_CPP_BASE_TYPES_TIMESTAMP_TYPE_H_ #define THIRD_PARTY_CEL_CPP_BASE_TYPES_TIMESTAMP_TYPE_H_ -#include -#include -#include - -#include "absl/strings/string_view.h" #include "base/kind.h" #include "base/type.h" namespace cel { -class TimestampType final : public Type { +class TimestampValue; + +class TimestampType final : public base_internal::SimpleType { + private: + using Base = base_internal::SimpleType; + public: - Kind kind() const override { return Kind::kTimestamp; } + using Base::kKind; - absl::string_view name() const override { - return "google.protobuf.Timestamp"; - } + using Base::kName; - private: - friend class TimestampValue; - friend class TypeFactory; - template - friend class internal::NoDestructor; - friend class base_internal::TypeHandleBase; + using Base::Is; + + using Base::kind; - // Called by base_internal::TypeHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Type& type) { return type.kind() == Kind::kTimestamp; } + using Base::name; - ABSL_ATTRIBUTE_PURE_FUNCTION static const TimestampType& Get(); + using Base::DebugString; - TimestampType() = default; + using Base::HashValue; - TimestampType(const TimestampType&) = delete; - TimestampType(TimestampType&&) = delete; + using Base::Equals; + + private: + CEL_INTERNAL_SIMPLE_TYPE_MEMBERS(TimestampType, TimestampValue); }; -CEL_INTERNAL_TYPE_DECL(TimestampType); +CEL_INTERNAL_SIMPLE_TYPE_STANDALONES(TimestampType); } // namespace cel diff --git a/base/types/type_type.cc b/base/types/type_type.cc index a6468b5a7..be4106864 100644 --- a/base/types/type_type.cc +++ b/base/types/type_type.cc @@ -14,15 +14,28 @@ #include "base/types/type_type.h" -#include "internal/no_destructor.h" +#include "absl/base/attributes.h" +#include "absl/base/call_once.h" namespace cel { CEL_INTERNAL_TYPE_IMPL(TypeType); -const TypeType& TypeType::Get() { - static const internal::NoDestructor instance; - return *instance; +namespace { + +ABSL_CONST_INIT absl::once_flag instance_once; +alignas(Persistent) char instance_storage[sizeof( + Persistent)]; + +} // namespace + +const Persistent& TypeType::Get() { + absl::call_once(instance_once, []() { + base_internal::PersistentHandleFactory::MakeAt( + &instance_storage[0]); + }); + return *reinterpret_cast*>( + &instance_storage[0]); } } // namespace cel diff --git a/base/types/type_type.h b/base/types/type_type.h index 711db96a0..a4f961dcd 100644 --- a/base/types/type_type.h +++ b/base/types/type_type.h @@ -15,43 +15,39 @@ #ifndef THIRD_PARTY_CEL_CPP_BASE_TYPES_TYPE_TYPE_H_ #define THIRD_PARTY_CEL_CPP_BASE_TYPES_TYPE_TYPE_H_ -#include -#include -#include - -#include "absl/strings/string_view.h" #include "base/kind.h" #include "base/type.h" namespace cel { -// TypeType represents the type of a type. -class TypeType final : public Type { +class TypeValue; + +class TypeType final : public base_internal::SimpleType { + private: + using Base = base_internal::SimpleType; + public: - Kind kind() const override { return Kind::kType; } + using Base::kKind; - absl::string_view name() const override { return "type"; } + using Base::kName; - private: - friend class TypeValue; - friend class TypeFactory; - template - friend class internal::NoDestructor; - friend class base_internal::TypeHandleBase; + using Base::Is; + + using Base::kind; - // Called by base_internal::TypeHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Type& type) { return type.kind() == Kind::kType; } + using Base::name; - ABSL_ATTRIBUTE_PURE_FUNCTION static const TypeType& Get(); + using Base::DebugString; - TypeType() = default; + using Base::HashValue; - TypeType(const TypeType&) = delete; - TypeType(TypeType&&) = delete; + using Base::Equals; + + private: + CEL_INTERNAL_SIMPLE_TYPE_MEMBERS(TypeType, TypeValue); }; -CEL_INTERNAL_TYPE_DECL(TypeType); +CEL_INTERNAL_SIMPLE_TYPE_STANDALONES(TypeType); } // namespace cel diff --git a/base/types/uint_type.cc b/base/types/uint_type.cc index 1632b1d4c..14ca4a85e 100644 --- a/base/types/uint_type.cc +++ b/base/types/uint_type.cc @@ -14,15 +14,28 @@ #include "base/types/uint_type.h" -#include "internal/no_destructor.h" +#include "absl/base/attributes.h" +#include "absl/base/call_once.h" namespace cel { CEL_INTERNAL_TYPE_IMPL(UintType); -const UintType& UintType::Get() { - static const internal::NoDestructor instance; - return *instance; +namespace { + +ABSL_CONST_INIT absl::once_flag instance_once; +alignas(Persistent) char instance_storage[sizeof( + Persistent)]; + +} // namespace + +const Persistent& UintType::Get() { + absl::call_once(instance_once, []() { + base_internal::PersistentHandleFactory::MakeAt( + &instance_storage[0]); + }); + return *reinterpret_cast*>( + &instance_storage[0]); } } // namespace cel diff --git a/base/types/uint_type.h b/base/types/uint_type.h index 1af31d0f1..65555c92d 100644 --- a/base/types/uint_type.h +++ b/base/types/uint_type.h @@ -15,42 +15,39 @@ #ifndef THIRD_PARTY_CEL_CPP_BASE_TYPES_UINT_TYPE_H_ #define THIRD_PARTY_CEL_CPP_BASE_TYPES_UINT_TYPE_H_ -#include -#include -#include - -#include "absl/strings/string_view.h" #include "base/kind.h" #include "base/type.h" namespace cel { -class UintType final : public Type { +class UintValue; + +class UintType final : public base_internal::SimpleType { + private: + using Base = base_internal::SimpleType; + public: - Kind kind() const override { return Kind::kUint; } + using Base::kKind; - absl::string_view name() const override { return "uint"; } + using Base::kName; - private: - friend class UintValue; - friend class TypeFactory; - template - friend class internal::NoDestructor; - friend class base_internal::TypeHandleBase; + using Base::Is; + + using Base::kind; - // Called by base_internal::TypeHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Type& type) { return type.kind() == Kind::kUint; } + using Base::name; - ABSL_ATTRIBUTE_PURE_FUNCTION static const UintType& Get(); + using Base::DebugString; - UintType() = default; + using Base::HashValue; - UintType(const UintType&) = delete; - UintType(UintType&&) = delete; + using Base::Equals; + + private: + CEL_INTERNAL_SIMPLE_TYPE_MEMBERS(UintType, UintValue); }; -CEL_INTERNAL_TYPE_DECL(UintType); +CEL_INTERNAL_SIMPLE_TYPE_STANDALONES(UintType); } // namespace cel diff --git a/base/value.cc b/base/value.cc index add9bdaa7..a16a490bb 100644 --- a/base/value.cc +++ b/base/value.cc @@ -15,19 +15,348 @@ #include "base/value.h" #include +#include #include +#include "absl/base/macros.h" +#include "base/values/bool_value.h" +#include "base/values/bytes_value.h" +#include "base/values/double_value.h" +#include "base/values/duration_value.h" +#include "base/values/enum_value.h" +#include "base/values/error_value.h" +#include "base/values/int_value.h" +#include "base/values/list_value.h" +#include "base/values/map_value.h" +#include "base/values/null_value.h" +#include "base/values/string_value.h" +#include "base/values/struct_value.h" +#include "base/values/timestamp_value.h" +#include "base/values/type_value.h" +#include "base/values/uint_value.h" + namespace cel { -std::pair Value::SizeAndAlignment() const { - // Currently most implementations of Value are not reference counted, so those - // that are override this and those that do not inherit this. Using 0 here - // will trigger runtime asserts in case of undefined behavior. - return std::pair(0, 0); +CEL_INTERNAL_VALUE_IMPL(Value); + +const Persistent& Value::type() const { + switch (kind()) { + case Kind::kNullType: + return static_cast(this)->type().As(); + case Kind::kError: + return static_cast(this)->type().As(); + case Kind::kType: + return static_cast(this)->type().As(); + case Kind::kBool: + return static_cast(this)->type().As(); + case Kind::kInt: + return static_cast(this)->type().As(); + case Kind::kUint: + return static_cast(this)->type().As(); + case Kind::kDouble: + return static_cast(this)->type().As(); + case Kind::kString: + return static_cast(this)->type().As(); + case Kind::kBytes: + return static_cast(this)->type().As(); + case Kind::kEnum: + return static_cast(this)->type().As(); + case Kind::kDuration: + return static_cast(this)->type().As(); + case Kind::kTimestamp: + return static_cast(this)->type().As(); + case Kind::kList: + return static_cast(this)->type().As(); + case Kind::kMap: + return static_cast(this)->type().As(); + case Kind::kStruct: + return static_cast(this)->type().As(); + default: + ABSL_INTERNAL_UNREACHABLE; + } +} + +std::string Value::DebugString() const { + switch (kind()) { + case Kind::kNullType: + return static_cast(this)->DebugString(); + case Kind::kError: + return static_cast(this)->DebugString(); + case Kind::kType: + return static_cast(this)->DebugString(); + case Kind::kBool: + return static_cast(this)->DebugString(); + case Kind::kInt: + return static_cast(this)->DebugString(); + case Kind::kUint: + return static_cast(this)->DebugString(); + case Kind::kDouble: + return static_cast(this)->DebugString(); + case Kind::kString: + return static_cast(this)->DebugString(); + case Kind::kBytes: + return static_cast(this)->DebugString(); + case Kind::kEnum: + return static_cast(this)->DebugString(); + case Kind::kDuration: + return static_cast(this)->DebugString(); + case Kind::kTimestamp: + return static_cast(this)->DebugString(); + case Kind::kList: + return static_cast(this)->DebugString(); + case Kind::kMap: + return static_cast(this)->DebugString(); + case Kind::kStruct: + return static_cast(this)->DebugString(); + default: + ABSL_INTERNAL_UNREACHABLE; + } +} + +void Value::HashValue(absl::HashState state) const { + switch (kind()) { + case Kind::kNullType: + return static_cast(this)->HashValue(std::move(state)); + case Kind::kError: + return static_cast(this)->HashValue(std::move(state)); + case Kind::kType: + return static_cast(this)->HashValue(std::move(state)); + case Kind::kBool: + return static_cast(this)->HashValue(std::move(state)); + case Kind::kInt: + return static_cast(this)->HashValue(std::move(state)); + case Kind::kUint: + return static_cast(this)->HashValue(std::move(state)); + case Kind::kDouble: + return static_cast(this)->HashValue(std::move(state)); + case Kind::kString: + return static_cast(this)->HashValue(std::move(state)); + case Kind::kBytes: + return static_cast(this)->HashValue(std::move(state)); + case Kind::kEnum: + return static_cast(this)->HashValue(std::move(state)); + case Kind::kDuration: + return static_cast(this)->HashValue( + std::move(state)); + case Kind::kTimestamp: + return static_cast(this)->HashValue( + std::move(state)); + case Kind::kList: + return static_cast(this)->HashValue(std::move(state)); + case Kind::kMap: + return static_cast(this)->HashValue(std::move(state)); + case Kind::kStruct: + return static_cast(this)->HashValue(std::move(state)); + default: + ABSL_INTERNAL_UNREACHABLE; + } +} + +bool Value::Equals(const Value& other) const { + if (this == &other) { + return true; + } + switch (kind()) { + case Kind::kNullType: + return static_cast(this)->Equals(other); + case Kind::kError: + return static_cast(this)->Equals(other); + case Kind::kType: + return static_cast(this)->Equals(other); + case Kind::kBool: + return static_cast(this)->Equals(other); + case Kind::kInt: + return static_cast(this)->Equals(other); + case Kind::kUint: + return static_cast(this)->Equals(other); + case Kind::kDouble: + return static_cast(this)->Equals(other); + case Kind::kString: + return static_cast(this)->Equals(other); + case Kind::kBytes: + return static_cast(this)->Equals(other); + case Kind::kEnum: + return static_cast(this)->Equals(other); + case Kind::kDuration: + return static_cast(this)->Equals(other); + case Kind::kTimestamp: + return static_cast(this)->Equals(other); + case Kind::kList: + return static_cast(this)->Equals(other); + case Kind::kMap: + return static_cast(this)->Equals(other); + case Kind::kStruct: + return static_cast(this)->Equals(other); + default: + ABSL_INTERNAL_UNREACHABLE; + } +} + +namespace base_internal { + +bool PersistentValueHandle::Equals(const PersistentValueHandle& other) const { + const auto* self = static_cast(data_.get()); + const auto* that = static_cast(other.data_.get()); + if (self == that) { + return true; + } + if (self == nullptr || that == nullptr) { + return false; + } + return *self == *that; +} + +void PersistentValueHandle::HashValue(absl::HashState state) const { + if (const auto* pointer = static_cast(data_.get()); + ABSL_PREDICT_TRUE(pointer != nullptr)) { + pointer->HashValue(std::move(state)); + } +} + +void PersistentValueHandle::CopyFrom(const PersistentValueHandle& other) { + // data_ is currently uninitialized. + auto locality = other.data_.locality(); + if (locality == DataLocality::kStoredInline && + !other.data_.IsTriviallyCopyable()) { + switch (other.data_.kind()) { + case Kind::kError: + data_.ConstructInline( + *static_cast(other.data_.get())); + break; + case Kind::kString: + data_.ConstructInline( + *static_cast(other.data_.get())); + break; + case Kind::kBytes: + data_.ConstructInline( + *static_cast(other.data_.get())); + break; + case Kind::kType: + data_.ConstructInline( + *static_cast(other.data_.get())); + break; + case Kind::kEnum: + data_.ConstructInline( + *static_cast(other.data_.get())); + break; + default: + ABSL_INTERNAL_UNREACHABLE; + } + } else { + // We can simply just copy the bytes. + data_.CopyFrom(other.data_); + if (locality == DataLocality::kReferenceCounted) { + Ref(); + } + } +} + +void PersistentValueHandle::MoveFrom(PersistentValueHandle& other) { + // data_ is currently uninitialized. + auto locality = other.data_.locality(); + if (locality == DataLocality::kStoredInline && + !other.data_.IsTriviallyCopyable()) { + switch (other.data_.kind()) { + case Kind::kError: + data_.ConstructInline( + std::move(*static_cast(other.data_.get()))); + break; + case Kind::kString: + data_.ConstructInline(std::move( + *static_cast(other.data_.get()))); + break; + case Kind::kBytes: + data_.ConstructInline( + std::move(*static_cast(other.data_.get()))); + break; + case Kind::kType: + data_.ConstructInline( + std::move(*static_cast(other.data_.get()))); + break; + case Kind::kEnum: + data_.ConstructInline( + std::move(*static_cast(other.data_.get()))); + break; + default: + ABSL_INTERNAL_UNREACHABLE; + } + other.Destruct(); + other.data_.Clear(); + } else { + // We can simply just copy the bytes. + data_.MoveFrom(other.data_); + } +} + +void PersistentValueHandle::CopyAssign(const PersistentValueHandle& other) { + // data_ is initialized. + Destruct(); + CopyFrom(other); +} + +void PersistentValueHandle::MoveAssign(PersistentValueHandle& other) { + // data_ is initialized. + Destruct(); + MoveFrom(other); } -void Value::CopyTo(Value& address) const {} +void PersistentValueHandle::Destruct() { + switch (data_.locality()) { + case DataLocality::kNull: + break; + case DataLocality::kStoredInline: + if (!data_.IsTriviallyDestructible()) { + switch (data_.kind()) { + case Kind::kError: + data_.Destruct(); + break; + case Kind::kString: + data_.Destruct(); + break; + case Kind::kBytes: + data_.Destruct(); + break; + case Kind::kType: + data_.Destruct(); + break; + case Kind::kEnum: + data_.Destruct(); + break; + default: + ABSL_INTERNAL_UNREACHABLE; + } + } + break; + case DataLocality::kReferenceCounted: + Unref(); + break; + case DataLocality::kArenaAllocated: + break; + } +} + +void PersistentValueHandle::Delete() const { + switch (data_.kind()) { + case Kind::kList: + delete static_cast(static_cast(data_.get())); + break; + case Kind::kMap: + delete static_cast(static_cast(data_.get())); + break; + case Kind::kStruct: + delete static_cast(static_cast(data_.get())); + break; + case Kind::kString: + delete static_cast(static_cast(data_.get())); + break; + case Kind::kBytes: + delete static_cast(static_cast(data_.get())); + break; + default: + ABSL_INTERNAL_UNREACHABLE; + } +} -void Value::MoveTo(Value& address) {} +} // namespace base_internal } // namespace cel diff --git a/base/value.h b/base/value.h index 04c678de1..e878e59aa 100644 --- a/base/value.h +++ b/base/value.h @@ -19,30 +19,27 @@ #include #include #include +#include #include +#include "absl/base/attributes.h" +#include "absl/hash/hash.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/variant.h" #include "base/handle.h" -#include "base/internal/value.pre.h" // IWYU pragma: export +#include "base/internal/value.h" // IWYU pragma: export #include "base/kind.h" -#include "base/memory_manager.h" #include "base/type.h" +#include "base/types/null_type.h" +#include "internal/casts.h" // IWYU pragma: keep namespace cel { class Value; -class NullValue; class ErrorValue; -class BoolValue; -class IntValue; -class UintValue; -class DoubleValue; class BytesValue; class StringValue; -class DurationValue; -class TimestampValue; class EnumValue; class StructValue; class ListValue; @@ -50,91 +47,265 @@ class MapValue; class TypeValue; class ValueFactory; -namespace internal { -template -class NoDestructor; -} - -namespace interop_internal { -base_internal::StringValueRep GetStringValueRep( - const Persistent& value); -base_internal::BytesValueRep GetBytesValueRep( - const Persistent& value); -} // namespace interop_internal - // A representation of a CEL value that enables reflection and introspection of // values. -class Value : public base_internal::Resource { +class Value : public base_internal::Data { public: - // Returns the type of the value. If you only need the kind, prefer `kind()`. - virtual Persistent type() const = 0; + static bool Is(const Value& value ABSL_ATTRIBUTE_UNUSED) { return true; } // Returns the kind of the value. This is equivalent to `type().kind()` but // faster in many scenarios. As such it should be preffered when only the kind // is required. - virtual Kind kind() const { return type()->kind(); } + Kind kind() const { return base_internal::Metadata::For(this)->kind(); } - virtual std::string DebugString() const = 0; + // Returns the type of the value. If you only need the kind, prefer `kind()`. + const Persistent& type() const; + + std::string DebugString() const; - // Called by base_internal::ValueHandleBase. - // Note GCC does not consider a friend member as a member of a friend. - virtual bool Equals(const Value& other) const = 0; + void HashValue(absl::HashState state) const; - // Called by base_internal::ValueHandleBase. - // Note GCC does not consider a friend member as a member of a friend. - virtual void HashValue(absl::HashState state) const = 0; + bool Equals(const Value& other) const; private: - friend class NullValue; friend class ErrorValue; - friend class BoolValue; - friend class IntValue; - friend class UintValue; - friend class DoubleValue; friend class BytesValue; friend class StringValue; - friend class DurationValue; - friend class TimestampValue; friend class EnumValue; friend class StructValue; friend class ListValue; friend class MapValue; friend class TypeValue; - friend class base_internal::ValueHandleBase; - friend class base_internal::StringBytesValue; - friend class base_internal::ExternalDataBytesValue; - friend class base_internal::StringStringValue; - friend class base_internal::ExternalDataStringValue; + friend class base_internal::PersistentValueHandle; + template + friend class base_internal::SimpleValue; Value() = default; Value(const Value&) = default; Value(Value&&) = default; + Value& operator=(const Value&) = default; + Value& operator=(Value&&) = default; +}; + +template +H AbslHashValue(H state, const Value& value) { + value.HashValue(absl::HashState::Create(&state)); + return state; +} + +inline bool operator==(const Value& lhs, const Value& rhs) { + return lhs.Equals(rhs); +} + +inline bool operator!=(const Value& lhs, const Value& rhs) { + return !operator==(lhs, rhs); +} + +} // namespace cel + +// ----------------------------------------------------------------------------- +// Internal implementation details. + +namespace cel { + +namespace base_internal { + +class PersistentValueHandle final { + public: + PersistentValueHandle() = default; + + template + explicit PersistentValueHandle(absl::in_place_type_t in_place_type, + Args&&... args) { + data_.ConstructInline(std::forward(args)...); + } + + explicit PersistentValueHandle(const Value& value) { + data_.ConstructHeap(value); + } + + PersistentValueHandle(const PersistentValueHandle& other) { CopyFrom(other); } + + PersistentValueHandle(PersistentValueHandle&& other) { MoveFrom(other); } + + ~PersistentValueHandle() { Destruct(); } + + PersistentValueHandle& operator=(const PersistentValueHandle& other) { + if (this != &other) { + CopyAssign(other); + } + return *this; + } + + PersistentValueHandle& operator=(PersistentValueHandle&& other) { + if (this != &other) { + MoveAssign(other); + } + return *this; + } + + Value* get() const { return static_cast(data_.get()); } + + explicit operator bool() const { return !data_.IsNull(); } + + bool Equals(const PersistentValueHandle& other) const; + + void HashValue(absl::HashState state) const; + + private: + void CopyFrom(const PersistentValueHandle& other); + + void MoveFrom(PersistentValueHandle& other); + + void CopyAssign(const PersistentValueHandle& other); + + void MoveAssign(PersistentValueHandle& other); + + void Ref() const { data_.Ref(); } + + void Unref() const { + if (data_.Unref()) { + Delete(); + } + } - // Called by base_internal::ValueHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Value& value) { return true; } + void Destruct(); - // For non-inlined values that are reference counted, this is the result of - // `sizeof` and `alignof` for the most derived class. - std::pair SizeAndAlignment() const override; + void Delete() const; - // Expose to some value implementations using friendship. - using base_internal::Resource::Ref; - using base_internal::Resource::Unref; + AnyValue data_; +}; + +template +H AbslHashValue(H state, const PersistentValueHandle& handle) { + handle.HashValue(absl::HashState::Create(&state)); + return state; +} + +inline bool operator==(const PersistentValueHandle& lhs, + const PersistentValueHandle& rhs) { + return lhs.Equals(rhs); +} + +inline bool operator!=(const PersistentValueHandle& lhs, + const PersistentValueHandle& rhs) { + return !operator==(lhs, rhs); +} + +// Specialization for Value providing the implementation to `Persistent`. +template <> +struct HandleTraits { + using handle_type = PersistentValueHandle; +}; + +// Partial specialization for `Persistent` for all classes derived from Value. +template +struct HandleTraits && + !std::is_same_v)>> + final : public HandleTraits {}; + +template +class SimpleValue : public Value, InlineData { + public: + static constexpr Kind kKind = T::kKind; + + static bool Is(const Value& value) { return value.kind() == kKind; } + + explicit SimpleValue(U value) : value_(value) {} + + SimpleValue(const SimpleValue&) = default; + SimpleValue(SimpleValue&&) = default; + SimpleValue& operator=(const SimpleValue&) = default; + SimpleValue& operator=(SimpleValue&&) = default; + + constexpr Kind kind() const { return kKind; } + + const Persistent& type() const { return T::Get(); } + + void HashValue(absl::HashState state) const { + absl::HashState::combine(std::move(state), type(), value()); + } + + bool Equals(const Value& other) const { + return type() == other.type() && + value() == static_cast&>(other).value(); + } - // Called by base_internal::ValueHandleBase for inlined values. - virtual void CopyTo(Value& address) const; + constexpr U value() const { return value_; } - // Called by base_internal::ValueHandleBase for inlined values. - virtual void MoveTo(Value& address); + private: + friend class PersistentValueHandle; + + static constexpr uintptr_t kVirtualPointer = + kStoredInline | + (std::is_trivially_copyable_v ? kTriviallyCopyable : 0) | + (std::is_trivially_destructible_v ? kTriviallyDestructible : 0) | + (static_cast(kKind) << kKindShift); + + uintptr_t vptr_ ABSL_ATTRIBUTE_UNUSED = kVirtualPointer; + U value_; +}; + +template <> +class SimpleValue : public Value, InlineData { + public: + static constexpr Kind kKind = Kind::kNullType; + + static bool Is(const Value& value) { return value.kind() == kKind; } + + SimpleValue() {} + + SimpleValue(const SimpleValue&) = default; + SimpleValue(SimpleValue&&) = default; + SimpleValue& operator=(const SimpleValue&) = default; + SimpleValue& operator=(SimpleValue&&) = default; + + constexpr Kind kind() const { return kKind; } + + const Persistent& type() const { return NullType::Get(); } + + void HashValue(absl::HashState state) const { + absl::HashState::combine(std::move(state), type(), 0); + } + + bool Equals(const Value& other) const { return kind() == other.kind(); } + + private: + friend class PersistentValueHandle; + + static constexpr uintptr_t kVirtualPointer = + kStoredInline | kTriviallyCopyable | kTriviallyDestructible | + (static_cast(kKind) << kKindShift); + + uintptr_t vptr_ ABSL_ATTRIBUTE_UNUSED = kVirtualPointer; }; +} // namespace base_internal + +CEL_INTERNAL_VALUE_DECL(Value); + } // namespace cel -// value.pre.h forward declares types so they can be friended above. The types -// themselves need to be defined after everything else as they need to access or -// derive from the above types. We do this in value.post.h to avoid mudying this -// header and making it difficult to read. -#include "base/internal/value.post.h" // IWYU pragma: export +#define CEL_INTERNAL_SIMPLE_VALUE_STANDALONES(value_class) \ + static_assert(std::is_trivially_copyable_v, \ + #value_class " must be trivially copyable"); \ + static_assert(std::is_trivially_destructible_v, \ + #value_class " must be trivially destructible"); \ + \ + CEL_INTERNAL_VALUE_DECL(value_class) + +#define CEL_INTERNAL_SIMPLE_VALUE_MEMBERS(value_class) \ + private: \ + friend class ValueFactory; \ + friend class base_internal::PersistentValueHandle; \ + template \ + friend class base_internal::AnyData; \ + \ + value_class() = default; \ + value_class(const value_class&) = default; \ + value_class(value_class&&) = default; \ + value_class& operator=(const value_class&) = default; \ + value_class& operator=(value_class&&) = default #endif // THIRD_PARTY_CEL_CPP_BASE_VALUE_H_ diff --git a/base/value_factory.cc b/base/value_factory.cc index 078294442..4f5025189 100644 --- a/base/value_factory.cc +++ b/base/value_factory.cc @@ -30,8 +30,6 @@ namespace cel { namespace { -using base_internal::ExternalDataBytesValue; -using base_internal::ExternalDataStringValue; using base_internal::InlinedCordBytesValue; using base_internal::InlinedCordStringValue; using base_internal::InlinedStringViewBytesValue; @@ -40,14 +38,6 @@ using base_internal::PersistentHandleFactory; using base_internal::StringBytesValue; using base_internal::StringStringValue; -template -bool CanPerformZeroCopy(MemoryManager& memory_manager, - const Persistent& handle) { - return base_internal::IsManagedHandle(handle) && - std::addressof(memory_manager) == - std::addressof(base_internal::GetMemoryManager(handle)); -} - } // namespace Persistent NullValue::Get(ValueFactory& value_factory) { @@ -55,9 +45,7 @@ Persistent NullValue::Get(ValueFactory& value_factory) { } Persistent ValueFactory::GetNullValue() { - return Persistent( - PersistentHandleFactory::MakeUnmanaged( - NullValue::Get())); + return PersistentHandleFactory::Make(); } Persistent ValueFactory::CreateErrorValue( @@ -113,26 +101,12 @@ Persistent StringValue::Empty(ValueFactory& value_factory) { } absl::StatusOr> StringValue::Concat( - ValueFactory& value_factory, const Persistent& lhs, - const Persistent& rhs) { + ValueFactory& value_factory, const StringValue& lhs, + const StringValue& rhs) { absl::Cord cord; - // We can only use the potential zero-copy path if the memory managers are - // the same. Otherwise we need to escape the original memory manager scope. - cord.Append( - lhs->ToCord(CanPerformZeroCopy(value_factory.memory_manager(), lhs))); - // We can only use the potential zero-copy path if the memory managers are - // the same. Otherwise we need to escape the original memory manager scope. - cord.Append( - rhs->ToCord(CanPerformZeroCopy(value_factory.memory_manager(), rhs))); - size_t size = 0; - size_t lhs_size = lhs->size_.load(std::memory_order_relaxed); - if (lhs_size != 0 && !lhs->empty()) { - size_t rhs_size = rhs->size_.load(std::memory_order_relaxed); - if (rhs_size != 0 && !rhs->empty()) { - size = lhs_size + rhs_size; - } - } - return value_factory.CreateStringValue(std::move(cord), size); + cord.Append(lhs.ToCord()); + cord.Append(rhs.ToCord()); + return value_factory.CreateStringValue(std::move(cord)); } Persistent BytesValue::Empty(ValueFactory& value_factory) { @@ -140,17 +114,10 @@ Persistent BytesValue::Empty(ValueFactory& value_factory) { } absl::StatusOr> BytesValue::Concat( - ValueFactory& value_factory, const Persistent& lhs, - const Persistent& rhs) { + ValueFactory& value_factory, const BytesValue& lhs, const BytesValue& rhs) { absl::Cord cord; - // We can only use the potential zero-copy path if the memory managers are - // the same. Otherwise we need to escape the original memory manager scope. - cord.Append( - lhs->ToCord(CanPerformZeroCopy(value_factory.memory_manager(), lhs))); - // We can only use the potential zero-copy path if the memory managers are - // the same. Otherwise we need to escape the original memory manager scope. - cord.Append( - rhs->ToCord(CanPerformZeroCopy(value_factory.memory_manager(), rhs))); + cord.Append(lhs.ToCord()); + cord.Append(rhs.ToCord()); return value_factory.CreateBytesValue(std::move(cord)); } @@ -191,12 +158,6 @@ absl::StatusOr> StructValue::New( ValueFactory& value_factory) { TypedStructValueFactory factory(value_factory, struct_type); CEL_ASSIGN_OR_RETURN(auto struct_value, struct_type->NewInstance(factory)); - if (!struct_value->type_) { - // In case somebody is caching, we avoid setting the type_ if it has already - // been set, to avoid a race condition where one CPU sees a half written - // pointer. - const_cast(*struct_value).type_ = struct_type; - } return struct_value; } @@ -247,7 +208,7 @@ absl::StatusOr> ValueFactory::CreateStringValue( "Illegal byte sequence in UTF-8 encoded string"); } return PersistentHandleFactory::Make( - memory_manager(), count, std::move(value)); + memory_manager(), std::move(value)); } absl::StatusOr> ValueFactory::CreateStringValue( @@ -260,7 +221,8 @@ absl::StatusOr> ValueFactory::CreateStringValue( return absl::InvalidArgumentError( "Illegal byte sequence in UTF-8 encoded string"); } - return CreateStringValue(std::move(value), count); + return PersistentHandleFactory::Make< + InlinedCordStringValue>(std::move(value)); } absl::StatusOr> @@ -293,32 +255,11 @@ Persistent ValueFactory::GetEmptyBytesValue() { InlinedStringViewBytesValue>(absl::string_view()); } -absl::StatusOr> ValueFactory::CreateBytesValue( - base_internal::ExternalData value) { - return PersistentHandleFactory::Make< - ExternalDataBytesValue>(memory_manager(), std::move(value)); -} - Persistent ValueFactory::GetEmptyStringValue() { return PersistentHandleFactory::Make< InlinedStringViewStringValue>(absl::string_view()); } -absl::StatusOr> ValueFactory::CreateStringValue( - absl::Cord value, size_t size) { - if (value.empty()) { - return GetEmptyStringValue(); - } - return PersistentHandleFactory::Make< - InlinedCordStringValue>(size, std::move(value)); -} - -absl::StatusOr> ValueFactory::CreateStringValue( - base_internal::ExternalData value) { - return PersistentHandleFactory::Make< - ExternalDataStringValue>(memory_manager(), std::move(value)); -} - absl::StatusOr> ValueFactory::CreateStringValueFromView(absl::string_view value) { return PersistentHandleFactory::Make< diff --git a/base/value_factory.h b/base/value_factory.h index 359e5ef6b..7e1c6fc86 100644 --- a/base/value_factory.h +++ b/base/value_factory.h @@ -120,10 +120,8 @@ class ValueFactory final { std::forward(releaser)(); return GetEmptyBytesValue(); } - return CreateBytesValue(base_internal::ExternalData( - static_cast(value.data()), value.size(), - std::make_unique( - std::forward(releaser)))); + return CreateBytesValue( + absl::MakeCordFromExternal(value, std::forward(releaser))); } Persistent GetStringValue() ABSL_ATTRIBUTE_LIFETIME_BOUND { @@ -154,10 +152,8 @@ class ValueFactory final { std::forward(releaser)(); return GetEmptyStringValue(); } - return CreateStringValue(base_internal::ExternalData( - static_cast(value.data()), value.size(), - std::make_unique( - std::forward(releaser)))); + return CreateStringValue( + absl::MakeCordFromExternal(value, std::forward(releaser))); } absl::StatusOr> CreateDurationValue( @@ -166,13 +162,19 @@ class ValueFactory final { absl::StatusOr> CreateTimestampValue( absl::Time value) ABSL_ATTRIBUTE_LIFETIME_BOUND; - template - EnableIfBaseOfT>> CreateEnumValue( + absl::StatusOr> CreateEnumValue( const Persistent& enum_type, - Args&&... args) ABSL_ATTRIBUTE_LIFETIME_BOUND { - return base_internal::PersistentHandleFactory::template Make< - std::remove_const_t>(memory_manager(), enum_type, - std::forward(args)...); + int64_t number) ABSL_ATTRIBUTE_LIFETIME_BOUND { + return base_internal::PersistentHandleFactory< + const EnumValue>::template Make(enum_type, number); + } + + template + std::enable_if_t, + absl::StatusOr>> + CreateEnumValue(const Persistent& enum_type, + T value) ABSL_ATTRIBUTE_LIFETIME_BOUND { + return CreateEnumValue(enum_type, static_cast(value)); } template @@ -222,21 +224,12 @@ class ValueFactory final { Persistent GetEmptyBytesValue() ABSL_ATTRIBUTE_LIFETIME_BOUND; - absl::StatusOr> CreateBytesValue( - base_internal::ExternalData value) ABSL_ATTRIBUTE_LIFETIME_BOUND; - absl::StatusOr> CreateBytesValueFromView( absl::string_view value) ABSL_ATTRIBUTE_LIFETIME_BOUND; Persistent GetEmptyStringValue() ABSL_ATTRIBUTE_LIFETIME_BOUND; - absl::StatusOr> CreateStringValue( - absl::Cord value, size_t size) ABSL_ATTRIBUTE_LIFETIME_BOUND; - - absl::StatusOr> CreateStringValue( - base_internal::ExternalData value) ABSL_ATTRIBUTE_LIFETIME_BOUND; - absl::StatusOr> CreateStringValueFromView( absl::string_view value) ABSL_ATTRIBUTE_LIFETIME_BOUND; @@ -257,11 +250,16 @@ class TypedEnumValueFactory final { const Persistent& enum_type ABSL_ATTRIBUTE_LIFETIME_BOUND) : value_factory_(value_factory), enum_type_(enum_type) {} - template - EnableIfBaseOfT>> CreateEnumValue( - Args&&... args) ABSL_ATTRIBUTE_LIFETIME_BOUND { - return value_factory_.CreateEnumValue(enum_type_, - std::forward(args)...); + absl::StatusOr> CreateEnumValue(int64_t number) + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return value_factory_.CreateEnumValue(enum_type_, number); + } + + template + std::enable_if_t, + absl::StatusOr>> + CreateEnumValue(T value) ABSL_ATTRIBUTE_LIFETIME_BOUND { + return CreateEnumValue(static_cast(value)); } private: diff --git a/base/value_test.cc b/base/value_test.cc index 926d7472a..23fa9c575 100644 --- a/base/value_test.cc +++ b/base/value_test.cc @@ -53,40 +53,6 @@ enum class TestEnum { kValue2 = 2, }; -class TestEnumValue final : public EnumValue { - public: - explicit TestEnumValue(const Persistent& type, - TestEnum test_enum) - : EnumValue(type), test_enum_(test_enum) {} - - std::string DebugString() const override { return std::string(name()); } - - absl::string_view name() const override { - switch (test_enum_) { - case TestEnum::kValue1: - return "VALUE1"; - case TestEnum::kValue2: - return "VALUE2"; - } - } - - int64_t number() const override { - switch (test_enum_) { - case TestEnum::kValue1: - return 1; - case TestEnum::kValue2: - return 2; - } - } - - private: - CEL_DECLARE_ENUM_VALUE(TestEnumValue); - - TestEnum test_enum_; -}; - -CEL_IMPLEMENT_ENUM_VALUE(TestEnumValue); - class TestEnumType final : public EnumType { public: using EnumType::EnumType; @@ -97,9 +63,9 @@ class TestEnumType final : public EnumType { absl::StatusOr> NewInstanceByName( TypedEnumValueFactory& factory, absl::string_view name) const override { if (name == "VALUE1") { - return factory.CreateEnumValue(TestEnum::kValue1); + return factory.CreateEnumValue(TestEnum::kValue1); } else if (name == "VALUE2") { - return factory.CreateEnumValue(TestEnum::kValue2); + return factory.CreateEnumValue(TestEnum::kValue2); } return absl::NotFoundError(""); } @@ -108,9 +74,9 @@ class TestEnumType final : public EnumType { TypedEnumValueFactory& factory, int64_t number) const override { switch (number) { case 1: - return factory.CreateEnumValue(TestEnum::kValue1); + return factory.CreateEnumValue(TestEnum::kValue1); case 2: - return factory.CreateEnumValue(TestEnum::kValue2); + return factory.CreateEnumValue(TestEnum::kValue2); default: return absl::NotFoundError(""); } @@ -122,7 +88,14 @@ class TestEnumType final : public EnumType { } absl::StatusOr FindConstantByNumber(int64_t number) const override { - return absl::UnimplementedError(""); + switch (number) { + case 1: + return Constant("VALUE1", 1); + case 2: + return Constant("VALUE2", 2); + default: + return absl::NotFoundError(""); + } } private: @@ -382,8 +355,7 @@ class TestListValue final : public ListValue { private: bool Equals(const Value& other) const override { return Is(other) && - elements_ == - internal::down_cast(other).elements_; + elements_ == static_cast(other).elements_; } void HashValue(absl::HashState state) const override { @@ -446,7 +418,7 @@ class TestMapValue final : public MapValue { private: bool Equals(const Value& other) const override { return Is(other) && - entries_ == internal::down_cast(other).entries_; + entries_ == static_cast(other).entries_; } void HashValue(absl::HashState state) const override { @@ -507,12 +479,6 @@ class BaseValueTest using ValueTest = BaseValueTest<>; -TEST(Value, HandleSize) { - // Advisory test to ensure we attempt to keep the size of Value handles under - // 32 bytes. As of the time of writing they are 24 bytes. - EXPECT_LE(sizeof(base_internal::ValueHandleData), 32); -} - TEST(Value, PersistentHandleTypeTraits) { EXPECT_TRUE(std::is_default_constructible_v>); EXPECT_TRUE(std::is_copy_constructible_v>); @@ -1307,49 +1273,49 @@ TEST_P(BytesConcatTest, Concat) { ValueFactory value_factory(type_manager); EXPECT_TRUE( Must(BytesValue::Concat(value_factory, - MakeStringBytes(value_factory, test_case().lhs), - MakeStringBytes(value_factory, test_case().rhs))) - ->Equals(test_case().lhs + test_case().rhs)); - EXPECT_TRUE( - Must(BytesValue::Concat(value_factory, - MakeStringBytes(value_factory, test_case().lhs), - MakeCordBytes(value_factory, test_case().rhs))) - ->Equals(test_case().lhs + test_case().rhs)); - EXPECT_TRUE( - Must(BytesValue::Concat( - value_factory, MakeStringBytes(value_factory, test_case().lhs), - MakeExternalBytes(value_factory, test_case().rhs))) - ->Equals(test_case().lhs + test_case().rhs)); - EXPECT_TRUE( - Must(BytesValue::Concat(value_factory, - MakeCordBytes(value_factory, test_case().lhs), - MakeStringBytes(value_factory, test_case().rhs))) + *MakeStringBytes(value_factory, test_case().lhs), + *MakeStringBytes(value_factory, test_case().rhs))) ->Equals(test_case().lhs + test_case().rhs)); EXPECT_TRUE( Must(BytesValue::Concat(value_factory, - MakeCordBytes(value_factory, test_case().lhs), - MakeCordBytes(value_factory, test_case().rhs))) + *MakeStringBytes(value_factory, test_case().lhs), + *MakeCordBytes(value_factory, test_case().rhs))) ->Equals(test_case().lhs + test_case().rhs)); EXPECT_TRUE( Must(BytesValue::Concat( - value_factory, MakeCordBytes(value_factory, test_case().lhs), - MakeExternalBytes(value_factory, test_case().rhs))) + value_factory, *MakeStringBytes(value_factory, test_case().lhs), + *MakeExternalBytes(value_factory, test_case().rhs))) ->Equals(test_case().lhs + test_case().rhs)); EXPECT_TRUE( Must(BytesValue::Concat(value_factory, - MakeExternalBytes(value_factory, test_case().lhs), - MakeStringBytes(value_factory, test_case().rhs))) + *MakeCordBytes(value_factory, test_case().lhs), + *MakeStringBytes(value_factory, test_case().rhs))) ->Equals(test_case().lhs + test_case().rhs)); EXPECT_TRUE( Must(BytesValue::Concat(value_factory, - MakeExternalBytes(value_factory, test_case().lhs), - MakeCordBytes(value_factory, test_case().rhs))) + *MakeCordBytes(value_factory, test_case().lhs), + *MakeCordBytes(value_factory, test_case().rhs))) ->Equals(test_case().lhs + test_case().rhs)); EXPECT_TRUE( Must(BytesValue::Concat( - value_factory, MakeExternalBytes(value_factory, test_case().lhs), - MakeExternalBytes(value_factory, test_case().rhs))) + value_factory, *MakeCordBytes(value_factory, test_case().lhs), + *MakeExternalBytes(value_factory, test_case().rhs))) ->Equals(test_case().lhs + test_case().rhs)); + EXPECT_TRUE(Must(BytesValue::Concat( + value_factory, + *MakeExternalBytes(value_factory, test_case().lhs), + *MakeStringBytes(value_factory, test_case().rhs))) + ->Equals(test_case().lhs + test_case().rhs)); + EXPECT_TRUE(Must(BytesValue::Concat( + value_factory, + *MakeExternalBytes(value_factory, test_case().lhs), + *MakeCordBytes(value_factory, test_case().rhs))) + ->Equals(test_case().lhs + test_case().rhs)); + EXPECT_TRUE(Must(BytesValue::Concat( + value_factory, + *MakeExternalBytes(value_factory, test_case().lhs), + *MakeExternalBytes(value_factory, test_case().rhs))) + ->Equals(test_case().lhs + test_case().rhs)); } INSTANTIATE_TEST_SUITE_P( @@ -1438,31 +1404,31 @@ TEST_P(BytesEqualsTest, Equals) { TypeManager type_manager(type_factory, TypeProvider::Builtin()); ValueFactory value_factory(type_manager); EXPECT_EQ(MakeStringBytes(value_factory, test_case().lhs) - ->Equals(MakeStringBytes(value_factory, test_case().rhs)), + ->Equals(*MakeStringBytes(value_factory, test_case().rhs)), test_case().equals); EXPECT_EQ(MakeStringBytes(value_factory, test_case().lhs) - ->Equals(MakeCordBytes(value_factory, test_case().rhs)), + ->Equals(*MakeCordBytes(value_factory, test_case().rhs)), test_case().equals); EXPECT_EQ(MakeStringBytes(value_factory, test_case().lhs) - ->Equals(MakeExternalBytes(value_factory, test_case().rhs)), + ->Equals(*MakeExternalBytes(value_factory, test_case().rhs)), test_case().equals); EXPECT_EQ(MakeCordBytes(value_factory, test_case().lhs) - ->Equals(MakeStringBytes(value_factory, test_case().rhs)), + ->Equals(*MakeStringBytes(value_factory, test_case().rhs)), test_case().equals); EXPECT_EQ(MakeCordBytes(value_factory, test_case().lhs) - ->Equals(MakeCordBytes(value_factory, test_case().rhs)), + ->Equals(*MakeCordBytes(value_factory, test_case().rhs)), test_case().equals); EXPECT_EQ(MakeCordBytes(value_factory, test_case().lhs) - ->Equals(MakeExternalBytes(value_factory, test_case().rhs)), + ->Equals(*MakeExternalBytes(value_factory, test_case().rhs)), test_case().equals); EXPECT_EQ(MakeExternalBytes(value_factory, test_case().lhs) - ->Equals(MakeStringBytes(value_factory, test_case().rhs)), + ->Equals(*MakeStringBytes(value_factory, test_case().rhs)), test_case().equals); EXPECT_EQ(MakeExternalBytes(value_factory, test_case().lhs) - ->Equals(MakeCordBytes(value_factory, test_case().rhs)), + ->Equals(*MakeCordBytes(value_factory, test_case().rhs)), test_case().equals); EXPECT_EQ(MakeExternalBytes(value_factory, test_case().lhs) - ->Equals(MakeExternalBytes(value_factory, test_case().rhs)), + ->Equals(*MakeExternalBytes(value_factory, test_case().rhs)), test_case().equals); } @@ -1496,43 +1462,45 @@ TEST_P(BytesCompareTest, Equals) { TypeFactory type_factory(memory_manager()); TypeManager type_manager(type_factory, TypeProvider::Builtin()); ValueFactory value_factory(type_manager); + EXPECT_EQ( + NormalizeCompareResult( + MakeStringBytes(value_factory, test_case().lhs) + ->Compare(*MakeStringBytes(value_factory, test_case().rhs))), + test_case().compare); EXPECT_EQ(NormalizeCompareResult( MakeStringBytes(value_factory, test_case().lhs) - ->Compare(MakeStringBytes(value_factory, test_case().rhs))), - test_case().compare); - EXPECT_EQ(NormalizeCompareResult( - MakeStringBytes(value_factory, test_case().lhs) - ->Compare(MakeCordBytes(value_factory, test_case().rhs))), + ->Compare(*MakeCordBytes(value_factory, test_case().rhs))), test_case().compare); EXPECT_EQ( NormalizeCompareResult( MakeStringBytes(value_factory, test_case().lhs) - ->Compare(MakeExternalBytes(value_factory, test_case().rhs))), + ->Compare(*MakeExternalBytes(value_factory, test_case().rhs))), test_case().compare); - EXPECT_EQ(NormalizeCompareResult( - MakeCordBytes(value_factory, test_case().lhs) - ->Compare(MakeStringBytes(value_factory, test_case().rhs))), + EXPECT_EQ(NormalizeCompareResult(MakeCordBytes(value_factory, test_case().lhs) + ->Compare(*MakeStringBytes( + value_factory, test_case().rhs))), test_case().compare); EXPECT_EQ(NormalizeCompareResult( MakeCordBytes(value_factory, test_case().lhs) - ->Compare(MakeCordBytes(value_factory, test_case().rhs))), + ->Compare(*MakeCordBytes(value_factory, test_case().rhs))), test_case().compare); EXPECT_EQ(NormalizeCompareResult(MakeCordBytes(value_factory, test_case().lhs) - ->Compare(MakeExternalBytes( + ->Compare(*MakeExternalBytes( value_factory, test_case().rhs))), test_case().compare); + EXPECT_EQ( + NormalizeCompareResult( + MakeExternalBytes(value_factory, test_case().lhs) + ->Compare(*MakeStringBytes(value_factory, test_case().rhs))), + test_case().compare); EXPECT_EQ(NormalizeCompareResult( MakeExternalBytes(value_factory, test_case().lhs) - ->Compare(MakeStringBytes(value_factory, test_case().rhs))), - test_case().compare); - EXPECT_EQ(NormalizeCompareResult( - MakeExternalBytes(value_factory, test_case().lhs) - ->Compare(MakeCordBytes(value_factory, test_case().rhs))), + ->Compare(*MakeCordBytes(value_factory, test_case().rhs))), test_case().compare); EXPECT_EQ( NormalizeCompareResult( MakeExternalBytes(value_factory, test_case().lhs) - ->Compare(MakeExternalBytes(value_factory, test_case().rhs))), + ->Compare(*MakeExternalBytes(value_factory, test_case().rhs))), test_case().compare); } @@ -1664,48 +1632,48 @@ TEST_P(StringConcatTest, Concat) { ValueFactory value_factory(type_manager); EXPECT_TRUE( Must(StringValue::Concat( - value_factory, MakeStringString(value_factory, test_case().lhs), - MakeStringString(value_factory, test_case().rhs))) + value_factory, *MakeStringString(value_factory, test_case().lhs), + *MakeStringString(value_factory, test_case().rhs))) ->Equals(test_case().lhs + test_case().rhs)); EXPECT_TRUE( - Must(StringValue::Concat(value_factory, - MakeStringString(value_factory, test_case().lhs), - MakeCordString(value_factory, test_case().rhs))) + Must(StringValue::Concat( + value_factory, *MakeStringString(value_factory, test_case().lhs), + *MakeCordString(value_factory, test_case().rhs))) ->Equals(test_case().lhs + test_case().rhs)); EXPECT_TRUE( Must(StringValue::Concat( - value_factory, MakeStringString(value_factory, test_case().lhs), - MakeExternalString(value_factory, test_case().rhs))) + value_factory, *MakeStringString(value_factory, test_case().lhs), + *MakeExternalString(value_factory, test_case().rhs))) ->Equals(test_case().lhs + test_case().rhs)); EXPECT_TRUE( Must(StringValue::Concat( - value_factory, MakeCordString(value_factory, test_case().lhs), - MakeStringString(value_factory, test_case().rhs))) + value_factory, *MakeCordString(value_factory, test_case().lhs), + *MakeStringString(value_factory, test_case().rhs))) ->Equals(test_case().lhs + test_case().rhs)); EXPECT_TRUE( Must(StringValue::Concat(value_factory, - MakeCordString(value_factory, test_case().lhs), - MakeCordString(value_factory, test_case().rhs))) + *MakeCordString(value_factory, test_case().lhs), + *MakeCordString(value_factory, test_case().rhs))) ->Equals(test_case().lhs + test_case().rhs)); EXPECT_TRUE( Must(StringValue::Concat( - value_factory, MakeCordString(value_factory, test_case().lhs), - MakeExternalString(value_factory, test_case().rhs))) + value_factory, *MakeCordString(value_factory, test_case().lhs), + *MakeExternalString(value_factory, test_case().rhs))) ->Equals(test_case().lhs + test_case().rhs)); EXPECT_TRUE(Must(StringValue::Concat( value_factory, - MakeExternalString(value_factory, test_case().lhs), - MakeStringString(value_factory, test_case().rhs))) + *MakeExternalString(value_factory, test_case().lhs), + *MakeStringString(value_factory, test_case().rhs))) ->Equals(test_case().lhs + test_case().rhs)); EXPECT_TRUE(Must(StringValue::Concat( value_factory, - MakeExternalString(value_factory, test_case().lhs), - MakeCordString(value_factory, test_case().rhs))) + *MakeExternalString(value_factory, test_case().lhs), + *MakeCordString(value_factory, test_case().rhs))) ->Equals(test_case().lhs + test_case().rhs)); EXPECT_TRUE(Must(StringValue::Concat( value_factory, - MakeExternalString(value_factory, test_case().lhs), - MakeExternalString(value_factory, test_case().rhs))) + *MakeExternalString(value_factory, test_case().lhs), + *MakeExternalString(value_factory, test_case().rhs))) ->Equals(test_case().lhs + test_case().rhs)); } @@ -1795,31 +1763,31 @@ TEST_P(StringEqualsTest, Equals) { TypeManager type_manager(type_factory, TypeProvider::Builtin()); ValueFactory value_factory(type_manager); EXPECT_EQ(MakeStringString(value_factory, test_case().lhs) - ->Equals(MakeStringString(value_factory, test_case().rhs)), + ->Equals(*MakeStringString(value_factory, test_case().rhs)), test_case().equals); EXPECT_EQ(MakeStringString(value_factory, test_case().lhs) - ->Equals(MakeCordString(value_factory, test_case().rhs)), + ->Equals(*MakeCordString(value_factory, test_case().rhs)), test_case().equals); EXPECT_EQ(MakeStringString(value_factory, test_case().lhs) - ->Equals(MakeExternalString(value_factory, test_case().rhs)), + ->Equals(*MakeExternalString(value_factory, test_case().rhs)), test_case().equals); EXPECT_EQ(MakeCordString(value_factory, test_case().lhs) - ->Equals(MakeStringString(value_factory, test_case().rhs)), + ->Equals(*MakeStringString(value_factory, test_case().rhs)), test_case().equals); EXPECT_EQ(MakeCordString(value_factory, test_case().lhs) - ->Equals(MakeCordString(value_factory, test_case().rhs)), + ->Equals(*MakeCordString(value_factory, test_case().rhs)), test_case().equals); EXPECT_EQ(MakeCordString(value_factory, test_case().lhs) - ->Equals(MakeExternalString(value_factory, test_case().rhs)), + ->Equals(*MakeExternalString(value_factory, test_case().rhs)), test_case().equals); EXPECT_EQ(MakeExternalString(value_factory, test_case().lhs) - ->Equals(MakeStringString(value_factory, test_case().rhs)), + ->Equals(*MakeStringString(value_factory, test_case().rhs)), test_case().equals); EXPECT_EQ(MakeExternalString(value_factory, test_case().lhs) - ->Equals(MakeCordString(value_factory, test_case().rhs)), + ->Equals(*MakeCordString(value_factory, test_case().rhs)), test_case().equals); EXPECT_EQ(MakeExternalString(value_factory, test_case().lhs) - ->Equals(MakeExternalString(value_factory, test_case().rhs)), + ->Equals(*MakeExternalString(value_factory, test_case().rhs)), test_case().equals); } @@ -1854,44 +1822,44 @@ TEST_P(StringCompareTest, Equals) { EXPECT_EQ( NormalizeCompareResult( MakeStringString(value_factory, test_case().lhs) - ->Compare(MakeStringString(value_factory, test_case().rhs))), + ->Compare(*MakeStringString(value_factory, test_case().rhs))), test_case().compare); EXPECT_EQ(NormalizeCompareResult( MakeStringString(value_factory, test_case().lhs) - ->Compare(MakeCordString(value_factory, test_case().rhs))), + ->Compare(*MakeCordString(value_factory, test_case().rhs))), test_case().compare); EXPECT_EQ( NormalizeCompareResult( MakeStringString(value_factory, test_case().lhs) - ->Compare(MakeExternalString(value_factory, test_case().rhs))), + ->Compare(*MakeExternalString(value_factory, test_case().rhs))), test_case().compare); EXPECT_EQ( NormalizeCompareResult( MakeCordString(value_factory, test_case().lhs) - ->Compare(MakeStringString(value_factory, test_case().rhs))), + ->Compare(*MakeStringString(value_factory, test_case().rhs))), test_case().compare); EXPECT_EQ(NormalizeCompareResult( MakeCordString(value_factory, test_case().lhs) - ->Compare(MakeCordString(value_factory, test_case().rhs))), + ->Compare(*MakeCordString(value_factory, test_case().rhs))), test_case().compare); EXPECT_EQ( NormalizeCompareResult( MakeCordString(value_factory, test_case().lhs) - ->Compare(MakeExternalString(value_factory, test_case().rhs))), + ->Compare(*MakeExternalString(value_factory, test_case().rhs))), test_case().compare); EXPECT_EQ( NormalizeCompareResult( MakeExternalString(value_factory, test_case().lhs) - ->Compare(MakeStringString(value_factory, test_case().rhs))), + ->Compare(*MakeStringString(value_factory, test_case().rhs))), test_case().compare); EXPECT_EQ(NormalizeCompareResult( MakeExternalString(value_factory, test_case().lhs) - ->Compare(MakeCordString(value_factory, test_case().rhs))), + ->Compare(*MakeCordString(value_factory, test_case().rhs))), test_case().compare); EXPECT_EQ( NormalizeCompareResult( MakeExternalString(value_factory, test_case().lhs) - ->Compare(MakeExternalString(value_factory, test_case().rhs))), + ->Compare(*MakeExternalString(value_factory, test_case().rhs))), test_case().compare); } @@ -2005,7 +1973,6 @@ TEST_P(ValueTest, Enum) { auto one_value, EnumValue::New(enum_type, value_factory, EnumType::ConstantId("VALUE1"))); EXPECT_TRUE(one_value.Is()); - EXPECT_TRUE(one_value.Is()); EXPECT_FALSE(one_value.Is()); EXPECT_EQ(one_value, one_value); EXPECT_EQ(one_value, Must(EnumValue::New(enum_type, value_factory, @@ -2019,7 +1986,6 @@ TEST_P(ValueTest, Enum) { auto two_value, EnumValue::New(enum_type, value_factory, EnumType::ConstantId("VALUE2"))); EXPECT_TRUE(two_value.Is()); - EXPECT_TRUE(two_value.Is()); EXPECT_FALSE(two_value.Is()); EXPECT_EQ(two_value, two_value); EXPECT_EQ(two_value->kind(), Kind::kEnum); diff --git a/base/values/bool_value.cc b/base/values/bool_value.cc index 6cec63cf9..9ab5e6252 100644 --- a/base/values/bool_value.cc +++ b/base/values/bool_value.cc @@ -15,43 +15,13 @@ #include "base/values/bool_value.h" #include -#include - -#include "base/types/bool_type.h" -#include "internal/casts.h" namespace cel { -namespace { - -using base_internal::PersistentHandleFactory; - -} - -Persistent BoolValue::type() const { - return PersistentHandleFactory::MakeUnmanaged( - BoolType::Get()); -} +CEL_INTERNAL_VALUE_IMPL(BoolValue); std::string BoolValue::DebugString() const { return value() ? "true" : "false"; } -void BoolValue::CopyTo(Value& address) const { - CEL_INTERNAL_VALUE_COPY_TO(BoolValue, *this, address); -} - -void BoolValue::MoveTo(Value& address) { - CEL_INTERNAL_VALUE_MOVE_TO(BoolValue, *this, address); -} - -bool BoolValue::Equals(const Value& other) const { - return kind() == other.kind() && - value() == internal::down_cast(other).value(); -} - -void BoolValue::HashValue(absl::HashState state) const { - absl::HashState::combine(std::move(state), type(), value()); -} - } // namespace cel diff --git a/base/values/bool_value.h b/base/values/bool_value.h index 5b34f7e1c..4aad27c61 100644 --- a/base/values/bool_value.h +++ b/base/values/bool_value.h @@ -17,55 +17,44 @@ #include -#include "absl/hash/hash.h" -#include "base/kind.h" -#include "base/type.h" +#include "base/types/bool_type.h" #include "base/value.h" namespace cel { -class ValueFactory; +class BoolValue final : public base_internal::SimpleValue { + private: + using Base = base_internal::SimpleValue; -class BoolValue final : public Value, public base_internal::ResourceInlined { public: - static Persistent False(ValueFactory& value_factory); + using Base::kKind; - static Persistent True(ValueFactory& value_factory); + using Base::Is; - Persistent type() const override; - - Kind kind() const override { return Kind::kBool; } + static Persistent False(ValueFactory& value_factory); - std::string DebugString() const override; + static Persistent True(ValueFactory& value_factory); - constexpr bool value() const { return value_; } + using Base::kind; - private: - template - friend class base_internal::ValueHandle; - friend class base_internal::ValueHandleBase; + using Base::type; - // Called by base_internal::ValueHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Value& value) { return value.kind() == Kind::kBool; } + std::string DebugString() const; - // Called by `base_internal::ValueHandle` to construct value inline. - explicit BoolValue(bool value) : value_(value) {} + using Base::HashValue; - BoolValue() = delete; + using Base::Equals; - BoolValue(const BoolValue&) = default; - BoolValue(BoolValue&&) = default; + using Base::value; - // See comments for respective member functions on `Value`. - void CopyTo(Value& address) const override; - void MoveTo(Value& address) override; - bool Equals(const Value& other) const override; - void HashValue(absl::HashState state) const override; + private: + using Base::Base; - bool value_; + CEL_INTERNAL_SIMPLE_VALUE_MEMBERS(BoolValue); }; +CEL_INTERNAL_SIMPLE_VALUE_STANDALONES(BoolValue); + } // namespace cel #endif // THIRD_PARTY_CEL_CPP_BASE_VALUES_BOOL_VALUE_H_ diff --git a/base/values/bytes_value.cc b/base/values/bytes_value.cc index d3f0d739c..f8f94cef6 100644 --- a/base/values/bytes_value.cc +++ b/base/values/bytes_value.cc @@ -17,15 +17,18 @@ #include #include +#include "absl/base/attributes.h" +#include "absl/base/macros.h" +#include "absl/strings/cord.h" +#include "base/internal/data.h" #include "base/types/bytes_type.h" -#include "internal/casts.h" #include "internal/strings.h" namespace cel { -namespace { +CEL_INTERNAL_VALUE_IMPL(BytesValue); -using base_internal::PersistentHandleFactory; +namespace { struct BytesValueDebugStringVisitor final { std::string operator()(absl::string_view value) const { @@ -169,11 +172,6 @@ class HashValueVisitor final { } // namespace -Persistent BytesValue::type() const { - return PersistentHandleFactory::MakeUnmanaged( - BytesType::Get()); -} - size_t BytesValue::size() const { return absl::visit(BytesValueSizeVisitor{}, rep()); } @@ -188,8 +186,8 @@ bool BytesValue::Equals(const absl::Cord& bytes) const { return absl::visit(EqualsVisitor(bytes), rep()); } -bool BytesValue::Equals(const Persistent& bytes) const { - return absl::visit(EqualsVisitor(*this), bytes->rep()); +bool BytesValue::Equals(const BytesValue& bytes) const { + return absl::visit(EqualsVisitor(*this), bytes.rep()); } int BytesValue::Compare(absl::string_view bytes) const { @@ -200,14 +198,43 @@ int BytesValue::Compare(const absl::Cord& bytes) const { return absl::visit(CompareVisitor(bytes), rep()); } -int BytesValue::Compare(const Persistent& bytes) const { - return absl::visit(CompareVisitor(*this), bytes->rep()); +int BytesValue::Compare(const BytesValue& bytes) const { + return absl::visit(CompareVisitor(*this), bytes.rep()); } std::string BytesValue::ToString() const { return absl::visit(ToStringVisitor{}, rep()); } +absl::Cord BytesValue::ToCord() const { + switch (base_internal::Metadata::For(this)->locality()) { + case base_internal::DataLocality::kNull: + return absl::Cord(); + case base_internal::DataLocality::kStoredInline: + if (base_internal::Metadata::For(this)->IsTriviallyCopyable()) { + return absl::MakeCordFromExternal( + static_cast(this) + ->value_, + []() {}); + } else { + return static_cast(this) + ->value_; + } + case base_internal::DataLocality::kReferenceCounted: + base_internal::Metadata::For(this)->Ref(); + return absl::MakeCordFromExternal( + static_cast(this)->value_, + [this]() { + if (base_internal::Metadata::For(this)->Unref()) { + delete static_cast(this); + } + }); + case base_internal::DataLocality::kArenaAllocated: + return absl::Cord( + static_cast(this)->value_); + } +} + std::string BytesValue::DebugString() const { return absl::visit(BytesValueDebugStringVisitor{}, rep()); } @@ -215,7 +242,7 @@ std::string BytesValue::DebugString() const { bool BytesValue::Equals(const Value& other) const { return kind() == other.kind() && absl::visit(EqualsVisitor(*this), - internal::down_cast(other).rep()); + static_cast(other).rep()); } void BytesValue::HashValue(absl::HashState state) const { @@ -224,81 +251,42 @@ void BytesValue::HashValue(absl::HashState state) const { rep()); } -namespace base_internal { - -absl::Cord InlinedCordBytesValue::ToCord(bool reference_counted) const { - static_cast(reference_counted); - return value_; -} - -void InlinedCordBytesValue::CopyTo(Value& address) const { - CEL_INTERNAL_VALUE_COPY_TO(InlinedCordBytesValue, *this, address); -} - -void InlinedCordBytesValue::MoveTo(Value& address) { - CEL_INTERNAL_VALUE_MOVE_TO(InlinedCordBytesValue, *this, address); -} - -typename InlinedCordBytesValue::Rep InlinedCordBytesValue::rep() const { - return Rep(absl::in_place_type>, - std::cref(value_)); -} - -absl::Cord InlinedStringViewBytesValue::ToCord(bool reference_counted) const { - static_cast(reference_counted); - return absl::Cord(value_); -} - -void InlinedStringViewBytesValue::CopyTo(Value& address) const { - CEL_INTERNAL_VALUE_COPY_TO(InlinedStringViewBytesValue, *this, address); -} - -void InlinedStringViewBytesValue::MoveTo(Value& address) { - CEL_INTERNAL_VALUE_MOVE_TO(InlinedStringViewBytesValue, *this, address); -} - -typename InlinedStringViewBytesValue::Rep InlinedStringViewBytesValue::rep() - const { - return Rep(absl::in_place_type, value_); -} - -std::pair StringBytesValue::SizeAndAlignment() const { - return std::make_pair(sizeof(StringBytesValue), alignof(StringBytesValue)); -} - -absl::Cord StringBytesValue::ToCord(bool reference_counted) const { - if (reference_counted) { - Ref(); - return absl::MakeCordFromExternal(absl::string_view(value_), - [this]() { Unref(); }); +base_internal::BytesValueRep BytesValue::rep() const { + switch (base_internal::Metadata::For(this)->locality()) { + case base_internal::DataLocality::kNull: + return base_internal::BytesValueRep(); + case base_internal::DataLocality::kStoredInline: + if (base_internal::Metadata::For(this)->IsTriviallyCopyable()) { + return base_internal::BytesValueRep( + absl::in_place_type, + static_cast(this) + ->value_); + } else { + return base_internal::BytesValueRep( + absl::in_place_type>, + std::cref( + static_cast(this) + ->value_)); + } + case base_internal::DataLocality::kReferenceCounted: + ABSL_FALLTHROUGH_INTENDED; + case base_internal::DataLocality::kArenaAllocated: + return base_internal::BytesValueRep( + absl::in_place_type, + absl::string_view( + static_cast(this) + ->value_)); } - return absl::Cord(value_); -} - -typename StringBytesValue::Rep StringBytesValue::rep() const { - return Rep(absl::in_place_type, absl::string_view(value_)); -} - -std::pair ExternalDataBytesValue::SizeAndAlignment() const { - return std::make_pair(sizeof(ExternalDataBytesValue), - alignof(ExternalDataBytesValue)); } -absl::Cord ExternalDataBytesValue::ToCord(bool reference_counted) const { - if (reference_counted) { - Ref(); - return absl::MakeCordFromExternal( - absl::string_view(static_cast(value_.data), value_.size), - [this]() { Unref(); }); - } - return absl::Cord( - absl::string_view(static_cast(value_.data), value_.size)); -} +namespace base_internal { -typename ExternalDataBytesValue::Rep ExternalDataBytesValue::rep() const { - return Rep( - absl::in_place_type, - absl::string_view(static_cast(value_.data), value_.size)); +StringBytesValue::StringBytesValue(std::string value) + : base_internal::HeapData(kKind), value_(std::move(value)) { + // Ensure `Value*` and `base_internal::HeapData*` are not thunked. + ABSL_ASSERT( + reinterpret_cast(static_cast(this)) == + reinterpret_cast(static_cast(this))); } } // namespace base_internal diff --git a/base/values/bytes_value.h b/base/values/bytes_value.h index f257b003a..e8bc3dfaa 100644 --- a/base/values/bytes_value.h +++ b/base/values/bytes_value.h @@ -15,41 +15,47 @@ #ifndef THIRD_PARTY_CEL_CPP_BASE_VALUES_BYTES_VALUE_H_ #define THIRD_PARTY_CEL_CPP_BASE_VALUES_BYTES_VALUE_H_ +#include #include #include #include +#include "absl/base/attributes.h" #include "absl/hash/hash.h" #include "absl/status/statusor.h" #include "absl/strings/cord.h" #include "absl/strings/string_view.h" +#include "base/internal/data.h" #include "base/kind.h" #include "base/type.h" +#include "base/types/bytes_type.h" #include "base/value.h" namespace cel { +class MemoryManager; class ValueFactory; class BytesValue : public Value { - protected: - using Rep = base_internal::BytesValueRep; - public: + static constexpr Kind kKind = BytesType::kKind; + static Persistent Empty(ValueFactory& value_factory); // Concat concatenates the contents of two ByteValue, returning a new // ByteValue. The resulting ByteValue is not tied to the lifetime of either of // the input ByteValue. static absl::StatusOr> Concat( - ValueFactory& value_factory, const Persistent& lhs, - const Persistent& rhs); + ValueFactory& value_factory, const BytesValue& lhs, + const BytesValue& rhs); - Persistent type() const final; + static bool Is(const Value& value) { return value.kind() == kKind; } - Kind kind() const final { return Kind::kBytes; } + constexpr Kind kind() const { return kKind; } - std::string DebugString() const final; + const Persistent& type() const { return BytesType::Get(); } + + std::string DebugString() const; size_t size() const; @@ -57,75 +63,65 @@ class BytesValue : public Value { bool Equals(absl::string_view bytes) const; bool Equals(const absl::Cord& bytes) const; - bool Equals(const Persistent& bytes) const; + bool Equals(const BytesValue& bytes) const; int Compare(absl::string_view bytes) const; int Compare(const absl::Cord& bytes) const; - int Compare(const Persistent& bytes) const; + int Compare(const BytesValue& bytes) const; std::string ToString() const; - absl::Cord ToCord() const { - // Without the handle we cannot know if this is reference counted. - return ToCord(/*reference_counted=*/false); - } + absl::Cord ToCord() const; + + void HashValue(absl::HashState state) const; + + bool Equals(const Value& other) const; private: - template - friend class base_internal::ValueHandle; - friend class base_internal::ValueHandleBase; + friend class base_internal::PersistentValueHandle; friend class base_internal::InlinedCordBytesValue; friend class base_internal::InlinedStringViewBytesValue; friend class base_internal::StringBytesValue; - friend class base_internal::ExternalDataBytesValue; friend base_internal::BytesValueRep interop_internal::GetBytesValueRep( const Persistent& value); - // Called by base_internal::ValueHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Value& value) { return value.kind() == Kind::kBytes; } - BytesValue() = default; BytesValue(const BytesValue&) = default; BytesValue(BytesValue&&) = default; - - // Get the contents of this BytesValue as absl::Cord. When reference_counted - // is true, the implementation can potentially return an absl::Cord that wraps - // the contents instead of copying. - virtual absl::Cord ToCord(bool reference_counted) const = 0; + BytesValue& operator=(const BytesValue&) = default; + BytesValue& operator=(BytesValue&&) = default; // Get the contents of this BytesValue as either absl::string_view or const // absl::Cord&. - virtual Rep rep() const = 0; - - // See comments for respective member functions on `Value`. - bool Equals(const Value& other) const final; - void HashValue(absl::HashState state) const final; + base_internal::BytesValueRep rep() const; }; +CEL_INTERNAL_VALUE_DECL(BytesValue); + namespace base_internal { // Implementation of BytesValue that is stored inlined within a handle. Since // absl::Cord is reference counted itself, this is more efficient than storing // this on the heap. -class InlinedCordBytesValue final : public BytesValue, public ResourceInlined { +class InlinedCordBytesValue final : public BytesValue, + public base_internal::InlineData { private: - template - friend class ValueHandle; + friend class BytesValue; + template + friend class AnyData; - explicit InlinedCordBytesValue(absl::Cord value) : value_(std::move(value)) {} + static constexpr uintptr_t kVirtualPointer = + base_internal::kStoredInline | + (static_cast(kKind) << base_internal::kKindShift); - InlinedCordBytesValue() = delete; + explicit InlinedCordBytesValue(absl::Cord value) : value_(std::move(value)) {} InlinedCordBytesValue(const InlinedCordBytesValue&) = default; InlinedCordBytesValue(InlinedCordBytesValue&&) = default; + InlinedCordBytesValue& operator=(const InlinedCordBytesValue&) = default; + InlinedCordBytesValue& operator=(InlinedCordBytesValue&&) = default; - // See comments for respective member functions on `ByteValue` and `Value`. - void CopyTo(Value& address) const override; - void MoveTo(Value& address) override; - absl::Cord ToCord(bool reference_counted) const override; - Rep rep() const override; - + uintptr_t vptr_ ABSL_ATTRIBUTE_UNUSED = kVirtualPointer; absl::Cord value_; }; @@ -134,70 +130,44 @@ class InlinedCordBytesValue final : public BytesValue, public ResourceInlined { // Typically this should only be used for empty strings or data that is static // and lives for the duration of a program. class InlinedStringViewBytesValue final : public BytesValue, - public ResourceInlined { + public base_internal::InlineData { private: - template - friend class ValueHandle; + friend class BytesValue; + template + friend class AnyData; + + static constexpr uintptr_t kVirtualPointer = + base_internal::kStoredInline | base_internal::kTriviallyCopyable | + base_internal::kTriviallyDestructible | + (static_cast(kKind) << base_internal::kKindShift); explicit InlinedStringViewBytesValue(absl::string_view value) : value_(value) {} - InlinedStringViewBytesValue() = delete; - InlinedStringViewBytesValue(const InlinedStringViewBytesValue&) = default; InlinedStringViewBytesValue(InlinedStringViewBytesValue&&) = default; + InlinedStringViewBytesValue& operator=(const InlinedStringViewBytesValue&) = + default; + InlinedStringViewBytesValue& operator=(InlinedStringViewBytesValue&&) = + default; - // See comments for respective member functions on `ByteValue` and `Value`. - void CopyTo(Value& address) const override; - void MoveTo(Value& address) override; - absl::Cord ToCord(bool reference_counted) const override; - Rep rep() const override; - + uintptr_t vptr_ ABSL_ATTRIBUTE_UNUSED = kVirtualPointer; absl::string_view value_; }; // Implementation of BytesValue that uses std::string and is allocated on the // heap, potentially reference counted. -class StringBytesValue final : public BytesValue { +class StringBytesValue final : public BytesValue, + public base_internal::HeapData { private: friend class cel::MemoryManager; + friend class BytesValue; - explicit StringBytesValue(std::string value) : value_(std::move(value)) {} - - StringBytesValue() = delete; - StringBytesValue(const StringBytesValue&) = delete; - StringBytesValue(StringBytesValue&&) = delete; - - // See comments for respective member functions on `ByteValue` and `Value`. - std::pair SizeAndAlignment() const override; - absl::Cord ToCord(bool reference_counted) const override; - Rep rep() const override; + explicit StringBytesValue(std::string value); std::string value_; }; -// Implementation of BytesValue that wraps a contiguous array of bytes and calls -// the releaser when it is no longer needed. It is stored on the heap and -// potentially reference counted. -class ExternalDataBytesValue final : public BytesValue { - private: - friend class cel::MemoryManager; - - explicit ExternalDataBytesValue(ExternalData value) - : value_(std::move(value)) {} - - ExternalDataBytesValue() = delete; - ExternalDataBytesValue(const ExternalDataBytesValue&) = delete; - ExternalDataBytesValue(ExternalDataBytesValue&&) = delete; - - // See comments for respective member functions on `ByteValue` and `Value`. - std::pair SizeAndAlignment() const override; - absl::Cord ToCord(bool reference_counted) const override; - Rep rep() const override; - - ExternalData value_; -}; - } // namespace base_internal } // namespace cel diff --git a/base/values/double_value.cc b/base/values/double_value.cc index a0fa86baf..89419dc10 100644 --- a/base/values/double_value.cc +++ b/base/values/double_value.cc @@ -16,25 +16,13 @@ #include #include -#include #include "absl/strings/match.h" #include "absl/strings/str_cat.h" -#include "base/types/double_type.h" -#include "internal/casts.h" namespace cel { -namespace { - -using base_internal::PersistentHandleFactory; - -} - -Persistent DoubleValue::type() const { - return PersistentHandleFactory::MakeUnmanaged( - DoubleType::Get()); -} +CEL_INTERNAL_VALUE_IMPL(DoubleValue); std::string DoubleValue::DebugString() const { if (std::isfinite(value())) { @@ -63,21 +51,4 @@ std::string DoubleValue::DebugString() const { return "+infinity"; } -void DoubleValue::CopyTo(Value& address) const { - CEL_INTERNAL_VALUE_COPY_TO(DoubleValue, *this, address); -} - -void DoubleValue::MoveTo(Value& address) { - CEL_INTERNAL_VALUE_MOVE_TO(DoubleValue, *this, address); -} - -bool DoubleValue::Equals(const Value& other) const { - return kind() == other.kind() && - value() == internal::down_cast(other).value(); -} - -void DoubleValue::HashValue(absl::HashState state) const { - absl::HashState::combine(std::move(state), type(), value()); -} - } // namespace cel diff --git a/base/values/double_value.h b/base/values/double_value.h index c9aa6cb52..e11f900d6 100644 --- a/base/values/double_value.h +++ b/base/values/double_value.h @@ -17,17 +17,21 @@ #include -#include "absl/hash/hash.h" -#include "base/kind.h" -#include "base/type.h" +#include "base/types/double_type.h" #include "base/value.h" namespace cel { -class ValueFactory; +class DoubleValue final + : public base_internal::SimpleValue { + private: + using Base = base_internal::SimpleValue; -class DoubleValue final : public Value, public base_internal::ResourceInlined { public: + using Base::kKind; + + using Base::Is; + static Persistent NaN(ValueFactory& value_factory); static Persistent PositiveInfinity( @@ -36,40 +40,26 @@ class DoubleValue final : public Value, public base_internal::ResourceInlined { static Persistent NegativeInfinity( ValueFactory& value_factory); - Persistent type() const override; + using Base::kind; - Kind kind() const override { return Kind::kDouble; } + using Base::type; - std::string DebugString() const override; + std::string DebugString() const; - constexpr double value() const { return value_; } - - private: - template - friend class base_internal::ValueHandle; - friend class base_internal::ValueHandleBase; + using Base::HashValue; - // Called by base_internal::ValueHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Value& value) { return value.kind() == Kind::kDouble; } + using Base::Equals; - // Called by `base_internal::ValueHandle` to construct value inline. - explicit DoubleValue(double value) : value_(value) {} + using Base::value; - DoubleValue() = delete; - - DoubleValue(const DoubleValue&) = default; - DoubleValue(DoubleValue&&) = default; - - // See comments for respective member functions on `Value`. - void CopyTo(Value& address) const override; - void MoveTo(Value& address) override; - bool Equals(const Value& other) const override; - void HashValue(absl::HashState state) const override; + private: + using Base::Base; - double value_; + CEL_INTERNAL_SIMPLE_VALUE_MEMBERS(DoubleValue); }; +CEL_INTERNAL_SIMPLE_VALUE_STANDALONES(DoubleValue); + } // namespace cel #endif // THIRD_PARTY_CEL_CPP_BASE_VALUES_DOUBLE_VALUE_H_ diff --git a/base/values/duration_value.cc b/base/values/duration_value.cc index 6e648cc29..8c239d0b4 100644 --- a/base/values/duration_value.cc +++ b/base/values/duration_value.cc @@ -15,44 +15,15 @@ #include "base/values/duration_value.h" #include -#include -#include "base/types/duration_type.h" -#include "internal/casts.h" #include "internal/time.h" namespace cel { -namespace { - -using base_internal::PersistentHandleFactory; - -} - -Persistent DurationValue::type() const { - return PersistentHandleFactory::MakeUnmanaged( - DurationType::Get()); -} +CEL_INTERNAL_VALUE_IMPL(DurationValue); std::string DurationValue::DebugString() const { return internal::FormatDuration(value()).value(); } -void DurationValue::CopyTo(Value& address) const { - CEL_INTERNAL_VALUE_COPY_TO(DurationValue, *this, address); -} - -void DurationValue::MoveTo(Value& address) { - CEL_INTERNAL_VALUE_MOVE_TO(DurationValue, *this, address); -} - -bool DurationValue::Equals(const Value& other) const { - return kind() == other.kind() && - value() == internal::down_cast(other).value(); -} - -void DurationValue::HashValue(absl::HashState state) const { - absl::HashState::combine(std::move(state), type(), value()); -} - } // namespace cel diff --git a/base/values/duration_value.h b/base/values/duration_value.h index 21b5c8381..69900ffc4 100644 --- a/base/values/duration_value.h +++ b/base/values/duration_value.h @@ -17,55 +17,44 @@ #include -#include "absl/hash/hash.h" #include "absl/time/time.h" -#include "base/kind.h" -#include "base/type.h" +#include "base/types/duration_type.h" #include "base/value.h" namespace cel { -class ValueFactory; +class DurationValue final + : public base_internal::SimpleValue { + private: + using Base = base_internal::SimpleValue; -class DurationValue final : public Value, - public base_internal::ResourceInlined { public: - static Persistent Zero(ValueFactory& value_factory); + using Base::kKind; - Persistent type() const override; + using Base::Is; - Kind kind() const override { return Kind::kDuration; } + static Persistent Zero(ValueFactory& value_factory); - std::string DebugString() const override; + using Base::kind; - constexpr absl::Duration value() const { return value_; } + using Base::type; - private: - template - friend class base_internal::ValueHandle; - friend class base_internal::ValueHandleBase; + std::string DebugString() const; - // Called by base_internal::ValueHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Value& value) { return value.kind() == Kind::kDuration; } + using Base::HashValue; - // Called by `base_internal::ValueHandle` to construct value inline. - explicit DurationValue(absl::Duration value) : value_(value) {} + using Base::Equals; - DurationValue() = delete; + using Base::value; - DurationValue(const DurationValue&) = default; - DurationValue(DurationValue&&) = default; - - // See comments for respective member functions on `Value`. - void CopyTo(Value& address) const override; - void MoveTo(Value& address) override; - bool Equals(const Value& other) const override; - void HashValue(absl::HashState state) const override; + private: + using Base::Base; - absl::Duration value_; + CEL_INTERNAL_SIMPLE_VALUE_MEMBERS(DurationValue); }; +CEL_INTERNAL_SIMPLE_VALUE_STANDALONES(DurationValue); + } // namespace cel #endif // THIRD_PARTY_CEL_CPP_BASE_VALUES_DURATION_VALUE_H_ diff --git a/base/values/enum_value.cc b/base/values/enum_value.cc index 2457c5fb9..24c02b6a5 100644 --- a/base/values/enum_value.cc +++ b/base/values/enum_value.cc @@ -17,13 +17,29 @@ #include #include -#include "internal/casts.h" - namespace cel { +CEL_INTERNAL_VALUE_IMPL(EnumValue); + +absl::string_view EnumValue::name() const { + auto constant = type()->FindConstantByNumber(number()); + if (!constant.ok()) { + return absl::string_view(); + } + return constant->name; +} + +std::string EnumValue::DebugString() const { + auto value = name(); + if (value.empty()) { + return absl::StrCat(type()->name(), "(", number(), ")"); + } + return absl::StrCat(type()->name(), ".", value); +} + bool EnumValue::Equals(const Value& other) const { return kind() == other.kind() && type() == other.type() && - number() == internal::down_cast(other).number(); + number() == static_cast(other).number(); } void EnumValue::HashValue(absl::HashState state) const { diff --git a/base/values/enum_value.h b/base/values/enum_value.h index 186c54811..4004dcd43 100644 --- a/base/values/enum_value.h +++ b/base/values/enum_value.h @@ -20,96 +20,63 @@ #include #include -#include "absl/base/macros.h" +#include "absl/base/attributes.h" #include "absl/hash/hash.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "base/internal/data.h" #include "base/kind.h" #include "base/type.h" #include "base/types/enum_type.h" #include "base/value.h" -#include "internal/rtti.h" namespace cel { class ValueFactory; // EnumValue represents a single constant belonging to cel::EnumType. -class EnumValue : public Value { +class EnumValue final : public Value, public base_internal::InlineData { public: + static constexpr Kind kKind = EnumType::kKind; + + static bool Is(const Value& value) { return value.kind() == kKind; } + static absl::StatusOr> New( const Persistent& enum_type, ValueFactory& value_factory, EnumType::ConstantId id); - Persistent type() const final { return type_; } - - Kind kind() const final { return Kind::kEnum; } + constexpr Kind kind() const { return kKind; } - virtual int64_t number() const = 0; + constexpr const Persistent& type() const { return type_; } - virtual absl::string_view name() const = 0; + std::string DebugString() const; - protected: - explicit EnumValue(const Persistent& type) : type_(type) { - ABSL_ASSERT(type_); - } + void HashValue(absl::HashState state) const; - private: - friend internal::TypeInfo base_internal::GetEnumValueTypeId( - const EnumValue& enum_value); - template - friend class base_internal::ValueHandle; - friend class base_internal::ValueHandleBase; + bool Equals(const Value& other) const; - // Called by base_internal::ValueHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Value& value) { return value.kind() == Kind::kEnum; } + constexpr int64_t number() const { return number_; } - EnumValue(const EnumValue&) = delete; - EnumValue(EnumValue&&) = delete; + absl::string_view name() const; - bool Equals(const Value& other) const final; - void HashValue(absl::HashState state) const final; + private: + friend class base_internal::PersistentValueHandle; + template + friend class base_internal::AnyData; - std::pair SizeAndAlignment() const override = 0; + static constexpr uintptr_t kVirtualPointer = + base_internal::kStoredInline | + (static_cast(kKind) << base_internal::kKindShift); - // Called by CEL_IMPLEMENT_ENUM_VALUE() and Is() to perform type checking. - virtual internal::TypeInfo TypeId() const = 0; + EnumValue(Persistent type, int64_t number) + : type_(std::move(type)), number_(number) {} + uintptr_t vptr_ ABSL_ATTRIBUTE_UNUSED = kVirtualPointer; Persistent type_; + int64_t number_; }; -// CEL_DECLARE_ENUM_VALUE declares `enum_value` as an enumeration value. It must -// be part of the class definition of `enum_value`. -// -// class MyEnumValue : public cel::EnumValue { -// ... -// private: -// CEL_DECLARE_ENUM_VALUE(MyEnumValue); -// }; -#define CEL_DECLARE_ENUM_VALUE(enum_value) \ - CEL_INTERNAL_DECLARE_VALUE(Enum, enum_value) - -// CEL_IMPLEMENT_ENUM_VALUE implements `enum_value` as an enumeration value. It -// must be called after the class definition of `enum_value`. -// -// class MyEnumValue : public cel::EnumValue { -// ... -// private: -// CEL_DECLARE_ENUM_VALUE(MyEnumValue); -// }; -// -// CEL_IMPLEMENT_ENUM_VALUE(MyEnumValue); -#define CEL_IMPLEMENT_ENUM_VALUE(enum_value) \ - CEL_INTERNAL_IMPLEMENT_VALUE(Enum, enum_value) - -namespace base_internal { - -inline internal::TypeInfo GetEnumValueTypeId(const EnumValue& enum_value) { - return enum_value.TypeId(); -} - -} // namespace base_internal +CEL_INTERNAL_VALUE_DECL(EnumValue); } // namespace cel diff --git a/base/values/error_value.cc b/base/values/error_value.cc index bad4884f6..998b9c794 100644 --- a/base/values/error_value.cc +++ b/base/values/error_value.cc @@ -14,17 +14,19 @@ #include "base/values/error_value.h" +#include #include #include -#include "base/types/error_type.h" -#include "internal/casts.h" +#include "absl/container/inlined_vector.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" namespace cel { -namespace { +CEL_INTERNAL_VALUE_IMPL(ErrorValue); -using base_internal::PersistentHandleFactory; +namespace { struct StatusPayload final { std::string key; @@ -59,24 +61,11 @@ void StatusHashValue(absl::HashState state, const absl::Status& status) { } // namespace -Persistent ErrorValue::type() const { - return PersistentHandleFactory::MakeUnmanaged( - ErrorType::Get()); -} - std::string ErrorValue::DebugString() const { return value().ToString(); } -void ErrorValue::CopyTo(Value& address) const { - CEL_INTERNAL_VALUE_COPY_TO(ErrorValue, *this, address); -} - -void ErrorValue::MoveTo(Value& address) { - CEL_INTERNAL_VALUE_MOVE_TO(ErrorValue, *this, address); -} - bool ErrorValue::Equals(const Value& other) const { return kind() == other.kind() && - value() == internal::down_cast(other).value(); + value() == static_cast(other).value(); } void ErrorValue::HashValue(absl::HashState state) const { diff --git a/base/values/error_value.h b/base/values/error_value.h index 5b0888c17..4807ddb40 100644 --- a/base/values/error_value.h +++ b/base/values/error_value.h @@ -15,6 +15,7 @@ #ifndef THIRD_PARTY_CEL_CPP_BASE_VALUES_ERROR_VALUE_H_ #define THIRD_PARTY_CEL_CPP_BASE_VALUES_ERROR_VALUE_H_ +#include #include #include @@ -22,46 +23,51 @@ #include "absl/status/status.h" #include "base/kind.h" #include "base/type.h" +#include "base/types/error_type.h" #include "base/value.h" namespace cel { -class ErrorValue final : public Value, public base_internal::ResourceInlined { +class ErrorValue final : public Value, public base_internal::InlineData { public: - Persistent type() const override; + static constexpr Kind kKind = ErrorType::kKind; - Kind kind() const override { return Kind::kError; } + static bool Is(const Value& value) { return value.kind() == kKind; } - std::string DebugString() const override; + constexpr Kind kind() const { return kKind; } - const absl::Status& value() const { return value_; } + const Persistent& type() const { return ErrorType::Get(); } + + std::string DebugString() const; + + void HashValue(absl::HashState state) const; + + bool Equals(const Value& other) const; + + constexpr const absl::Status& value() const { return value_; } private: - template - friend class base_internal::ValueHandle; - friend class base_internal::ValueHandleBase; + friend class PersistentValueHandle; + template + friend class base_internal::AnyData; - // Called by base_internal::ValueHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Value& value) { return value.kind() == Kind::kError; } + static constexpr uintptr_t kVirtualPointer = + base_internal::kStoredInline | + (static_cast(kKind) << base_internal::kKindShift); - // Called by `base_internal::ValueHandle` to construct value inline. explicit ErrorValue(absl::Status value) : value_(std::move(value)) {} - ErrorValue() = delete; - ErrorValue(const ErrorValue&) = default; ErrorValue(ErrorValue&&) = default; + ErrorValue& operator=(const ErrorValue&) = default; + ErrorValue& operator=(ErrorValue&&) = default; - // See comments for respective member functions on `Value`. - void CopyTo(Value& address) const override; - void MoveTo(Value& address) override; - bool Equals(const Value& other) const override; - void HashValue(absl::HashState state) const override; - + uintptr_t vptr_ ABSL_ATTRIBUTE_UNUSED = kVirtualPointer; absl::Status value_; }; +CEL_INTERNAL_VALUE_DECL(ErrorValue); + } // namespace cel #endif // THIRD_PARTY_CEL_CPP_BASE_VALUES_ERROR_VALUE_H_ diff --git a/base/values/int_value.cc b/base/values/int_value.cc index a1588ffec..c7bfcfdf7 100644 --- a/base/values/int_value.cc +++ b/base/values/int_value.cc @@ -15,42 +15,13 @@ #include "base/values/int_value.h" #include -#include #include "absl/strings/str_cat.h" -#include "base/types/int_type.h" -#include "internal/casts.h" namespace cel { -namespace { - -using base_internal::PersistentHandleFactory; - -} - -Persistent IntValue::type() const { - return PersistentHandleFactory::MakeUnmanaged( - IntType::Get()); -} +CEL_INTERNAL_VALUE_IMPL(IntValue); std::string IntValue::DebugString() const { return absl::StrCat(value()); } -void IntValue::CopyTo(Value& address) const { - CEL_INTERNAL_VALUE_COPY_TO(IntValue, *this, address); -} - -void IntValue::MoveTo(Value& address) { - CEL_INTERNAL_VALUE_MOVE_TO(IntValue, *this, address); -} - -bool IntValue::Equals(const Value& other) const { - return kind() == other.kind() && - value() == internal::down_cast(other).value(); -} - -void IntValue::HashValue(absl::HashState state) const { - absl::HashState::combine(std::move(state), type(), value()); -} - } // namespace cel diff --git a/base/values/int_value.h b/base/values/int_value.h index f2fb17a08..d5a17073e 100644 --- a/base/values/int_value.h +++ b/base/values/int_value.h @@ -18,51 +18,40 @@ #include #include -#include "absl/hash/hash.h" -#include "base/kind.h" -#include "base/type.h" +#include "base/types/int_type.h" #include "base/value.h" namespace cel { -class IntValue; +class IntValue final : public base_internal::SimpleValue { + private: + using Base = base_internal::SimpleValue; -class IntValue final : public Value, public base_internal::ResourceInlined { public: - Persistent type() const override; - - Kind kind() const override { return Kind::kInt; } + using Base::kKind; - std::string DebugString() const override; + using Base::Is; - constexpr int64_t value() const { return value_; } + using Base::kind; - private: - template - friend class base_internal::ValueHandle; - friend class base_internal::ValueHandleBase; + using Base::type; - // Called by base_internal::ValueHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Value& value) { return value.kind() == Kind::kInt; } + std::string DebugString() const; - // Called by `base_internal::ValueHandle` to construct value inline. - explicit IntValue(int64_t value) : value_(value) {} + using Base::HashValue; - IntValue() = delete; + using Base::Equals; - IntValue(const IntValue&) = default; - IntValue(IntValue&&) = default; + using Base::value; - // See comments for respective member functions on `Value`. - void CopyTo(Value& address) const override; - void MoveTo(Value& address) override; - bool Equals(const Value& other) const override; - void HashValue(absl::HashState state) const override; + private: + using Base::Base; - int64_t value_; + CEL_INTERNAL_SIMPLE_VALUE_MEMBERS(IntValue); }; +CEL_INTERNAL_SIMPLE_VALUE_STANDALONES(IntValue); + } // namespace cel #endif // THIRD_PARTY_CEL_CPP_BASE_VALUES_INT_VALUE_H_ diff --git a/base/values/list_value.cc b/base/values/list_value.cc index 4af7217e0..aef7d8671 100644 --- a/base/values/list_value.cc +++ b/base/values/list_value.cc @@ -14,8 +14,20 @@ #include "base/values/list_value.h" +#include + +#include "absl/base/macros.h" + namespace cel { -// +CEL_INTERNAL_VALUE_IMPL(ListValue); + +ListValue::ListValue(Persistent type) + : base_internal::HeapData(kKind), type_(std::move(type)) { + // Ensure `Value*` and `base_internal::HeapData*` are not thunked. + ABSL_ASSERT( + reinterpret_cast(static_cast(this)) == + reinterpret_cast(static_cast(this))); +} } // namespace cel diff --git a/base/values/list_value.h b/base/values/list_value.h index abe5b630f..d0964b9ab 100644 --- a/base/values/list_value.h +++ b/base/values/list_value.h @@ -23,6 +23,7 @@ #include "absl/hash/hash.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "base/internal/data.h" #include "base/kind.h" #include "base/type.h" #include "base/types/list_type.h" @@ -34,13 +35,19 @@ namespace cel { class ValueFactory; // ListValue represents an instance of cel::ListType. -class ListValue : public Value { +class ListValue : public Value, public base_internal::HeapData { public: + static constexpr Kind kKind = ListType::kKind; + + static bool Is(const Value& value) { return value.kind() == kKind; } + // TODO(issues/5): implement iterators so we can have cheap concated lists - Persistent type() const final { return type_; } + constexpr const Persistent& type() const { return type_; } + + constexpr Kind kind() const { return kKind; } - Kind kind() const final { return Kind::kList; } + virtual std::string DebugString() const = 0; virtual size_t size() const = 0; @@ -49,31 +56,17 @@ class ListValue : public Value { virtual absl::StatusOr> Get( ValueFactory& value_factory, size_t index) const = 0; + virtual bool Equals(const Value& other) const = 0; + + virtual void HashValue(absl::HashState state) const = 0; + protected: - explicit ListValue(const Persistent& type) : type_(type) {} + explicit ListValue(Persistent type); private: friend internal::TypeInfo base_internal::GetListValueTypeId( const ListValue& list_value); - template - friend class base_internal::ValueHandle; - friend class base_internal::ValueHandleBase; - - // Called by base_internal::ValueHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Value& value) { return value.kind() == Kind::kList; } - - ListValue(const ListValue&) = delete; - ListValue(ListValue&&) = delete; - - // TODO(issues/5): I do not like this, we should have these two take a - // ValueFactory and return absl::StatusOr and absl::Status. We support - // lazily created values, so errors can occur during equality testing. - // Especially if there are different value implementations for the same type. - bool Equals(const Value& other) const override = 0; - void HashValue(absl::HashState state) const override = 0; - - std::pair SizeAndAlignment() const override = 0; + friend class base_internal::PersistentValueHandle; // Called by CEL_IMPLEMENT_LIST_VALUE() and Is() to perform type checking. virtual internal::TypeInfo TypeId() const = 0; @@ -81,6 +74,8 @@ class ListValue : public Value { const Persistent type_; }; +CEL_INTERNAL_VALUE_DECL(ListValue); + // CEL_DECLARE_LIST_VALUE declares `list_value` as an list value. It must // be part of the class definition of `list_value`. // diff --git a/base/values/map_value.cc b/base/values/map_value.cc index 37c6405d0..fb281b1d6 100644 --- a/base/values/map_value.cc +++ b/base/values/map_value.cc @@ -14,8 +14,20 @@ #include "base/values/map_value.h" +#include + +#include "absl/base/macros.h" + namespace cel { -// +CEL_INTERNAL_VALUE_IMPL(MapValue); + +MapValue::MapValue(Persistent type) + : base_internal::HeapData(kKind), type_(std::move(type)) { + // Ensure `Value*` and `base_internal::HeapData*` are not thunked. + ABSL_ASSERT( + reinterpret_cast(static_cast(this)) == + reinterpret_cast(static_cast(this))); +} } // namespace cel diff --git a/base/values/map_value.h b/base/values/map_value.h index db907feac..d7f5f32ac 100644 --- a/base/values/map_value.h +++ b/base/values/map_value.h @@ -15,6 +15,7 @@ #ifndef THIRD_PARTY_CEL_CPP_BASE_VALUES_MAP_VALUE_H_ #define THIRD_PARTY_CEL_CPP_BASE_VALUES_MAP_VALUE_H_ +#include #include #include #include @@ -23,6 +24,7 @@ #include "absl/hash/hash.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "base/internal/data.h" #include "base/kind.h" #include "base/type.h" #include "base/types/map_type.h" @@ -34,16 +36,26 @@ namespace cel { class ValueFactory; // MapValue represents an instance of cel::MapType. -class MapValue : public Value { +class MapValue : public Value, public base_internal::HeapData { public: - Persistent type() const final { return type_; } + static constexpr Kind kKind = MapType::kKind; - Kind kind() const final { return Kind::kMap; } + static bool Is(const Value& value) { return value.kind() == kKind; } + + constexpr Kind kind() const { return kKind; } + + constexpr const Persistent& type() const { return type_; } + + virtual std::string DebugString() const = 0; virtual size_t size() const = 0; virtual bool empty() const { return size() == 0; } + virtual bool Equals(const Value& other) const = 0; + + virtual void HashValue(absl::HashState state) const = 0; + virtual absl::StatusOr> Get( ValueFactory& value_factory, const Persistent& key) const = 0; @@ -52,34 +64,21 @@ class MapValue : public Value { const Persistent& key) const = 0; protected: - explicit MapValue(const Persistent& type) : type_(type) {} + explicit MapValue(Persistent type); private: friend internal::TypeInfo base_internal::GetMapValueTypeId( const MapValue& map_value); - template - friend class base_internal::ValueHandle; - friend class base_internal::ValueHandleBase; - - // Called by base_internal::ValueHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Value& value) { return value.kind() == Kind::kMap; } - - MapValue(const MapValue&) = delete; - MapValue(MapValue&&) = delete; - - bool Equals(const Value& other) const override = 0; - void HashValue(absl::HashState state) const override = 0; - - std::pair SizeAndAlignment() const override = 0; + friend class base_internal::PersistentValueHandle; // Called by CEL_IMPLEMENT_MAP_VALUE() and Is() to perform type checking. virtual internal::TypeInfo TypeId() const = 0; - // Set lazily, by EnumValue::New. - Persistent type_; + const Persistent type_; }; +CEL_INTERNAL_VALUE_DECL(MapValue); + // CEL_DECLARE_MAP_VALUE declares `map_value` as an map value. It must // be part of the class definition of `map_value`. // diff --git a/base/values/null_value.cc b/base/values/null_value.cc index 87f90d368..db653553f 100644 --- a/base/values/null_value.cc +++ b/base/values/null_value.cc @@ -15,45 +15,11 @@ #include "base/values/null_value.h" #include -#include - -#include "base/types/null_type.h" -#include "internal/no_destructor.h" namespace cel { -namespace { - -using base_internal::PersistentHandleFactory; - -} - -Persistent NullValue::type() const { - return PersistentHandleFactory::MakeUnmanaged( - NullType::Get()); -} +CEL_INTERNAL_VALUE_IMPL(NullValue); std::string NullValue::DebugString() const { return "null"; } -const NullValue& NullValue::Get() { - static const internal::NoDestructor instance; - return *instance; -} - -void NullValue::CopyTo(Value& address) const { - CEL_INTERNAL_VALUE_COPY_TO(NullValue, *this, address); -} - -void NullValue::MoveTo(Value& address) { - CEL_INTERNAL_VALUE_MOVE_TO(NullValue, *this, address); -} - -bool NullValue::Equals(const Value& other) const { - return kind() == other.kind(); -} - -void NullValue::HashValue(absl::HashState state) const { - absl::HashState::combine(std::move(state), type(), 0); -} - } // namespace cel diff --git a/base/values/null_value.h b/base/values/null_value.h index 53098c059..0c5be829c 100644 --- a/base/values/null_value.h +++ b/base/values/null_value.h @@ -17,53 +17,38 @@ #include -#include "absl/hash/hash.h" -#include "base/kind.h" -#include "base/type.h" +#include "base/types/null_type.h" #include "base/value.h" namespace cel { -class ValueFactory; +class NullValue final : public base_internal::SimpleValue { + private: + using Base = base_internal::SimpleValue; -class NullValue final : public Value, public base_internal::ResourceInlined { public: - static Persistent Get(ValueFactory& value_factory); + using Base::kKind; - Persistent type() const override; + using Base::Is; - Kind kind() const override { return Kind::kNullType; } + static Persistent Get(ValueFactory& value_factory); - std::string DebugString() const override; + using Base::kind; - // Note GCC does not consider a friend member as a member of a friend. - ABSL_ATTRIBUTE_PURE_FUNCTION static const NullValue& Get(); + using Base::type; - bool Equals(const Value& other) const override; + std::string DebugString() const; - void HashValue(absl::HashState state) const override; + using Base::HashValue; + + using Base::Equals; private: - friend class ValueFactory; - template - friend class internal::NoDestructor; - template - friend class base_internal::ValueHandle; - friend class base_internal::ValueHandleBase; - - // Called by base_internal::ValueHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Value& value) { return value.kind() == Kind::kNullType; } - - NullValue() = default; - NullValue(const NullValue&) = default; - NullValue(NullValue&&) = default; - - // See comments for respective member functions on `Value`. - void CopyTo(Value& address) const override; - void MoveTo(Value& address) override; + CEL_INTERNAL_SIMPLE_VALUE_MEMBERS(NullValue); }; +CEL_INTERNAL_SIMPLE_VALUE_STANDALONES(NullValue); + } // namespace cel #endif // THIRD_PARTY_CEL_CPP_BASE_VALUES_NULL_VALUE_H_ diff --git a/base/values/string_value.cc b/base/values/string_value.cc index 77fb5c437..e4795492a 100644 --- a/base/values/string_value.cc +++ b/base/values/string_value.cc @@ -17,16 +17,16 @@ #include #include +#include "absl/base/macros.h" #include "base/types/string_type.h" -#include "internal/casts.h" #include "internal/strings.h" #include "internal/utf8.h" namespace cel { -namespace { +CEL_INTERNAL_VALUE_IMPL(StringValue); -using base_internal::PersistentHandleFactory; +namespace { struct StringValueDebugStringVisitor final { std::string operator()(absl::string_view value) const { @@ -174,22 +174,8 @@ class HashValueVisitor final { } // namespace -Persistent StringValue::type() const { - return PersistentHandleFactory::MakeUnmanaged( - StringType::Get()); -} - size_t StringValue::size() const { - // We lazily calculate the code point count in some circumstances. If the code - // point count is 0 and the underlying rep is not empty we need to actually - // calculate the size. It is okay if this is done by multiple threads - // simultaneously, it is a benign race. - size_t size = size_.load(std::memory_order_relaxed); - if (size == 0 && !empty()) { - size = absl::visit(StringValueSizeVisitor{}, rep()); - size_.store(size, std::memory_order_relaxed); - } - return size; + return absl::visit(StringValueSizeVisitor{}, rep()); } bool StringValue::empty() const { return absl::visit(EmptyVisitor{}, rep()); } @@ -202,8 +188,8 @@ bool StringValue::Equals(const absl::Cord& string) const { return absl::visit(EqualsVisitor(string), rep()); } -bool StringValue::Equals(const Persistent& string) const { - return absl::visit(EqualsVisitor(*this), string->rep()); +bool StringValue::Equals(const StringValue& string) const { + return absl::visit(EqualsVisitor(*this), string.rep()); } int StringValue::Compare(absl::string_view string) const { @@ -214,14 +200,44 @@ int StringValue::Compare(const absl::Cord& string) const { return absl::visit(CompareVisitor(string), rep()); } -int StringValue::Compare(const Persistent& string) const { - return absl::visit(CompareVisitor(*this), string->rep()); +int StringValue::Compare(const StringValue& string) const { + return absl::visit(CompareVisitor(*this), string.rep()); } std::string StringValue::ToString() const { return absl::visit(ToStringVisitor{}, rep()); } +absl::Cord StringValue::ToCord() const { + switch (base_internal::Metadata::For(this)->locality()) { + case base_internal::DataLocality::kNull: + return absl::Cord(); + case base_internal::DataLocality::kStoredInline: + if (base_internal::Metadata::For(this)->IsTriviallyCopyable()) { + return absl::MakeCordFromExternal( + static_cast( + this) + ->value_, + []() {}); + } else { + return static_cast(this) + ->value_; + } + case base_internal::DataLocality::kReferenceCounted: + base_internal::Metadata::For(this)->Ref(); + return absl::MakeCordFromExternal( + static_cast(this)->value_, + [this]() { + if (base_internal::Metadata::For(this)->Unref()) { + delete static_cast(this); + } + }); + case base_internal::DataLocality::kArenaAllocated: + return absl::Cord( + static_cast(this)->value_); + } +} + std::string StringValue::DebugString() const { return absl::visit(StringValueDebugStringVisitor{}, rep()); } @@ -229,7 +245,7 @@ std::string StringValue::DebugString() const { bool StringValue::Equals(const Value& other) const { return kind() == other.kind() && absl::visit(EqualsVisitor(*this), - internal::down_cast(other).rep()); + static_cast(other).rep()); } void StringValue::HashValue(absl::HashState state) const { @@ -238,81 +254,43 @@ void StringValue::HashValue(absl::HashState state) const { rep()); } -namespace base_internal { - -absl::Cord InlinedCordStringValue::ToCord(bool reference_counted) const { - static_cast(reference_counted); - return value_; -} - -void InlinedCordStringValue::CopyTo(Value& address) const { - CEL_INTERNAL_VALUE_COPY_TO(InlinedCordStringValue, *this, address); -} - -void InlinedCordStringValue::MoveTo(Value& address) { - CEL_INTERNAL_VALUE_MOVE_TO(InlinedCordStringValue, *this, address); -} - -typename InlinedCordStringValue::Rep InlinedCordStringValue::rep() const { - return Rep(absl::in_place_type>, - std::cref(value_)); -} - -absl::Cord InlinedStringViewStringValue::ToCord(bool reference_counted) const { - static_cast(reference_counted); - return absl::Cord(value_); -} - -void InlinedStringViewStringValue::CopyTo(Value& address) const { - CEL_INTERNAL_VALUE_COPY_TO(InlinedStringViewStringValue, *this, address); -} - -void InlinedStringViewStringValue::MoveTo(Value& address) { - CEL_INTERNAL_VALUE_MOVE_TO(InlinedStringViewStringValue, *this, address); -} - -typename InlinedStringViewStringValue::Rep InlinedStringViewStringValue::rep() - const { - return Rep(absl::in_place_type, value_); -} - -std::pair StringStringValue::SizeAndAlignment() const { - return std::make_pair(sizeof(StringStringValue), alignof(StringStringValue)); -} - -absl::Cord StringStringValue::ToCord(bool reference_counted) const { - if (reference_counted) { - Ref(); - return absl::MakeCordFromExternal(absl::string_view(value_), - [this]() { Unref(); }); +base_internal::StringValueRep StringValue::rep() const { + switch (base_internal::Metadata::For(this)->locality()) { + case base_internal::DataLocality::kNull: + return base_internal::StringValueRep(); + case base_internal::DataLocality::kStoredInline: + if (base_internal::Metadata::For(this)->IsTriviallyCopyable()) { + return base_internal::StringValueRep( + absl::in_place_type, + static_cast( + this) + ->value_); + } else { + return base_internal::StringValueRep( + absl::in_place_type>, + std::cref( + static_cast(this) + ->value_)); + } + case base_internal::DataLocality::kReferenceCounted: + ABSL_FALLTHROUGH_INTENDED; + case base_internal::DataLocality::kArenaAllocated: + return base_internal::StringValueRep( + absl::in_place_type, + absl::string_view( + static_cast(this) + ->value_)); } - return absl::Cord(value_); -} - -typename StringStringValue::Rep StringStringValue::rep() const { - return Rep(absl::in_place_type, absl::string_view(value_)); -} - -std::pair ExternalDataStringValue::SizeAndAlignment() const { - return std::make_pair(sizeof(ExternalDataStringValue), - alignof(ExternalDataStringValue)); } -absl::Cord ExternalDataStringValue::ToCord(bool reference_counted) const { - if (reference_counted) { - Ref(); - return absl::MakeCordFromExternal( - absl::string_view(static_cast(value_.data), value_.size), - [this]() { Unref(); }); - } - return absl::Cord( - absl::string_view(static_cast(value_.data), value_.size)); -} +namespace base_internal { -typename ExternalDataStringValue::Rep ExternalDataStringValue::rep() const { - return Rep( - absl::in_place_type, - absl::string_view(static_cast(value_.data), value_.size)); +StringStringValue::StringStringValue(std::string value) + : base_internal::HeapData(kKind), value_(std::move(value)) { + // Ensure `Value*` and `base_internal::HeapData*` are not thunked. + ABSL_ASSERT( + reinterpret_cast(static_cast(this)) == + reinterpret_cast(static_cast(this))); } } // namespace base_internal diff --git a/base/values/string_value.h b/base/values/string_value.h index edd07381b..abb85e2aa 100644 --- a/base/values/string_value.h +++ b/base/values/string_value.h @@ -15,38 +15,47 @@ #ifndef THIRD_PARTY_CEL_CPP_BASE_VALUES_STRING_VALUE_H_ #define THIRD_PARTY_CEL_CPP_BASE_VALUES_STRING_VALUE_H_ +#include #include #include #include +#include "absl/base/attributes.h" #include "absl/hash/hash.h" #include "absl/status/statusor.h" #include "absl/strings/cord.h" #include "absl/strings/string_view.h" +#include "base/internal/data.h" #include "base/kind.h" #include "base/type.h" +#include "base/types/string_type.h" #include "base/value.h" namespace cel { +class MemoryManager; class ValueFactory; class StringValue : public Value { - protected: - using Rep = base_internal::StringValueRep; - public: + static constexpr Kind kKind = StringType::kKind; + static Persistent Empty(ValueFactory& value_factory); + // Concat concatenates the contents of two ByteValue, returning a new + // ByteValue. The resulting ByteValue is not tied to the lifetime of either of + // the input ByteValue. static absl::StatusOr> Concat( - ValueFactory& value_factory, const Persistent& lhs, - const Persistent& rhs); + ValueFactory& value_factory, const StringValue& lhs, + const StringValue& rhs); + + static bool Is(const Value& value) { return value.kind() == kKind; } - Persistent type() const final; + constexpr Kind kind() const { return kKind; } - Kind kind() const final { return Kind::kString; } + const Persistent& type() const { return StringType::Get(); } - std::string DebugString() const final; + std::string DebugString() const; size_t size() const; @@ -54,89 +63,67 @@ class StringValue : public Value { bool Equals(absl::string_view string) const; bool Equals(const absl::Cord& string) const; - bool Equals(const Persistent& string) const; + bool Equals(const StringValue& string) const; int Compare(absl::string_view string) const; int Compare(const absl::Cord& string) const; - int Compare(const Persistent& string) const; + int Compare(const StringValue& string) const; std::string ToString() const; - absl::Cord ToCord() const { - // Without the handle we cannot know if this is reference counted. - return ToCord(/*reference_counted=*/false); - } + absl::Cord ToCord() const; + + void HashValue(absl::HashState state) const; + + bool Equals(const Value& other) const; private: - template - friend class base_internal::ValueHandle; - friend class base_internal::ValueHandleBase; + friend class base_internal::PersistentValueHandle; friend class base_internal::InlinedCordStringValue; friend class base_internal::InlinedStringViewStringValue; friend class base_internal::StringStringValue; - friend class base_internal::ExternalDataStringValue; friend base_internal::StringValueRep interop_internal::GetStringValueRep( const Persistent& value); - // Called by base_internal::ValueHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Value& value) { return value.kind() == Kind::kString; } - - explicit StringValue(size_t size) : size_(size) {} - StringValue() = default; - - StringValue(const StringValue& other) - : StringValue(other.size_.load(std::memory_order_relaxed)) {} - - StringValue(StringValue&& other) - : StringValue(other.size_.exchange(0, std::memory_order_relaxed)) {} - - // Get the contents of this BytesValue as absl::Cord. When reference_counted - // is true, the implementation can potentially return an absl::Cord that wraps - // the contents instead of copying. - virtual absl::Cord ToCord(bool reference_counted) const = 0; + StringValue(const StringValue&) = default; + StringValue(StringValue&&) = default; + StringValue& operator=(const StringValue&) = default; + StringValue& operator=(StringValue&&) = default; // Get the contents of this StringValue as either absl::string_view or const // absl::Cord&. - virtual Rep rep() const = 0; - - // See comments for respective member functions on `Value`. - bool Equals(const Value& other) const final; - void HashValue(absl::HashState state) const final; - - // Lazily cached code point count. - mutable std::atomic size_ = 0; + base_internal::StringValueRep rep() const; }; +CEL_INTERNAL_VALUE_DECL(StringValue); + namespace base_internal { // Implementation of StringValue that is stored inlined within a handle. Since -// absl::Cord is reference counted itself, this is more efficient then storing +// absl::Cord is reference counted itself, this is more efficient than storing // this on the heap. class InlinedCordStringValue final : public StringValue, - public ResourceInlined { + public base_internal::InlineData { private: - template - friend class ValueHandle; + friend class StringValue; + friend class ValueFactory; + template + friend class AnyData; - explicit InlinedCordStringValue(absl::Cord value) - : InlinedCordStringValue(0, std::move(value)) {} - - InlinedCordStringValue(size_t size, absl::Cord value) - : StringValue(size), value_(std::move(value)) {} + static constexpr uintptr_t kVirtualPointer = + base_internal::kStoredInline | + (static_cast(kKind) << base_internal::kKindShift); - InlinedCordStringValue() = delete; + explicit InlinedCordStringValue(absl::Cord value) + : value_(std::move(value)) {} InlinedCordStringValue(const InlinedCordStringValue&) = default; InlinedCordStringValue(InlinedCordStringValue&&) = default; + InlinedCordStringValue& operator=(const InlinedCordStringValue&) = default; + InlinedCordStringValue& operator=(InlinedCordStringValue&&) = default; - // See comments for respective member functions on `StringValue` and `Value`. - void CopyTo(Value& address) const override; - void MoveTo(Value& address) override; - absl::Cord ToCord(bool reference_counted) const override; - Rep rep() const override; - + uintptr_t vptr_ ABSL_ATTRIBUTE_UNUSED = kVirtualPointer; absl::Cord value_; }; @@ -145,80 +132,46 @@ class InlinedCordStringValue final : public StringValue, // Typically this should only be used for empty strings or data that is static // and lives for the duration of a program. class InlinedStringViewStringValue final : public StringValue, - public ResourceInlined { + public base_internal::InlineData { private: - template - friend class ValueHandle; + friend class StringValue; + friend class ValueFactory; + template + friend class AnyData; - explicit InlinedStringViewStringValue(absl::string_view value) - : InlinedStringViewStringValue(0, value) {} - - InlinedStringViewStringValue(size_t size, absl::string_view value) - : StringValue(size), value_(value) {} + static constexpr uintptr_t kVirtualPointer = + base_internal::kStoredInline | base_internal::kTriviallyCopyable | + base_internal::kTriviallyDestructible | + (static_cast(kKind) << base_internal::kKindShift); - InlinedStringViewStringValue() = delete; + explicit InlinedStringViewStringValue(absl::string_view value) + : value_(value) {} InlinedStringViewStringValue(const InlinedStringViewStringValue&) = default; InlinedStringViewStringValue(InlinedStringViewStringValue&&) = default; + InlinedStringViewStringValue& operator=(const InlinedStringViewStringValue&) = + default; + InlinedStringViewStringValue& operator=(InlinedStringViewStringValue&&) = + default; - // See comments for respective member functions on `StringValue` and `Value`. - void CopyTo(Value& address) const override; - void MoveTo(Value& address) override; - absl::Cord ToCord(bool reference_counted) const override; - Rep rep() const override; - + uintptr_t vptr_ ABSL_ATTRIBUTE_UNUSED = kVirtualPointer; absl::string_view value_; }; // Implementation of StringValue that uses std::string and is allocated on the // heap, potentially reference counted. -class StringStringValue final : public StringValue { +class StringStringValue final : public StringValue, + public base_internal::HeapData { private: friend class cel::MemoryManager; + friend class StringValue; + friend class ValueFactory; - explicit StringStringValue(std::string value) - : StringStringValue(0, std::move(value)) {} - - StringStringValue(size_t size, std::string value) - : StringValue(size), value_(std::move(value)) {} - - StringStringValue() = delete; - StringStringValue(const StringStringValue&) = delete; - StringStringValue(StringStringValue&&) = delete; - - // See comments for respective member functions on `StringValue` and `Value`. - std::pair SizeAndAlignment() const override; - absl::Cord ToCord(bool reference_counted) const override; - Rep rep() const override; + explicit StringStringValue(std::string value); std::string value_; }; -// Implementation of StringValue that wraps a contiguous array of bytes and -// calls the releaser when it is no longer needed. It is stored on the heap and -// potentially reference counted. -class ExternalDataStringValue final : public StringValue { - private: - friend class cel::MemoryManager; - - explicit ExternalDataStringValue(ExternalData value) - : ExternalDataStringValue(0, std::move(value)) {} - - ExternalDataStringValue(size_t size, ExternalData value) - : StringValue(size), value_(std::move(value)) {} - - ExternalDataStringValue() = delete; - ExternalDataStringValue(const ExternalDataStringValue&) = delete; - ExternalDataStringValue(ExternalDataStringValue&&) = delete; - - // See comments for respective member functions on `StringValue` and `Value`. - std::pair SizeAndAlignment() const override; - absl::Cord ToCord(bool reference_counted) const override; - Rep rep() const override; - - ExternalData value_; -}; - } // namespace base_internal } // namespace cel diff --git a/base/values/struct_value.cc b/base/values/struct_value.cc index 7ca2e21b5..c2931a262 100644 --- a/base/values/struct_value.cc +++ b/base/values/struct_value.cc @@ -17,10 +17,21 @@ #include #include +#include "absl/base/macros.h" #include "base/types/struct_type.h" namespace cel { +CEL_INTERNAL_VALUE_IMPL(StructValue); + +StructValue::StructValue(Persistent type) + : base_internal::HeapData(kKind), type_(std::move(type)) { + // Ensure `Value*` and `base_internal::HeapData*` are not thunked. + ABSL_ASSERT( + reinterpret_cast(static_cast(this)) == + reinterpret_cast(static_cast(this))); +} + struct StructValue::SetFieldVisitor final { StructValue& struct_value; const Persistent& value; diff --git a/base/values/struct_value.h b/base/values/struct_value.h index 8afc3c04e..0e63fdc1c 100644 --- a/base/values/struct_value.h +++ b/base/values/struct_value.h @@ -20,7 +20,6 @@ #include #include -#include "absl/base/macros.h" #include "absl/hash/hash.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" @@ -35,17 +34,27 @@ namespace cel { class ValueFactory; // StructValue represents an instance of cel::StructType. -class StructValue : public Value { +class StructValue : public Value, public base_internal::HeapData { public: + static constexpr Kind kKind = Kind::kStruct; + + static bool Is(const Value& value) { return value.kind() == kKind; } + using FieldId = StructType::FieldId; static absl::StatusOr> New( const Persistent& struct_type, ValueFactory& value_factory); - Persistent type() const final { return type_; } + constexpr Kind kind() const { return kKind; } + + constexpr const Persistent& type() const { return type_; } + + virtual std::string DebugString() const = 0; - Kind kind() const final { return Kind::kStruct; } + virtual void HashValue(absl::HashState state) const = 0; + + virtual bool Equals(const Value& other) const = 0; absl::Status SetField(FieldId field, const Persistent& value); @@ -55,9 +64,7 @@ class StructValue : public Value { absl::StatusOr HasField(FieldId field) const; protected: - explicit StructValue(const Persistent& type) : type_(type) { - ABSL_ASSERT(type_); - } + explicit StructValue(Persistent type); virtual absl::Status SetFieldByName(absl::string_view name, const Persistent& value) = 0; @@ -85,28 +92,22 @@ class StructValue : public Value { friend struct HasFieldVisitor; friend internal::TypeInfo base_internal::GetStructValueTypeId( const StructValue& struct_value); - template - friend class base_internal::ValueHandle; - friend class base_internal::ValueHandleBase; + friend class base_internal::PersistentValueHandle; // Called by base_internal::ValueHandleBase to implement Is for Transient and // Persistent. - static bool Is(const Value& value) { return value.kind() == Kind::kStruct; } StructValue(const StructValue&) = delete; StructValue(StructValue&&) = delete; - bool Equals(const Value& other) const override = 0; - void HashValue(absl::HashState state) const override = 0; - - std::pair SizeAndAlignment() const override = 0; - // Called by CEL_IMPLEMENT_STRUCT_VALUE() and Is() to perform type checking. virtual internal::TypeInfo TypeId() const = 0; - Persistent type_; + const Persistent type_; }; +CEL_INTERNAL_VALUE_DECL(StructValue); + // CEL_DECLARE_STRUCT_VALUE declares `struct_value` as an struct value. It must // be part of the class definition of `struct_value`. // diff --git a/base/values/timestamp_value.cc b/base/values/timestamp_value.cc index 2cc2079c5..b573df3b0 100644 --- a/base/values/timestamp_value.cc +++ b/base/values/timestamp_value.cc @@ -15,44 +15,15 @@ #include "base/values/timestamp_value.h" #include -#include -#include "base/types/timestamp_type.h" -#include "internal/casts.h" #include "internal/time.h" namespace cel { -namespace { - -using base_internal::PersistentHandleFactory; - -} - -Persistent TimestampValue::type() const { - return PersistentHandleFactory::MakeUnmanaged< - const TimestampType>(TimestampType::Get()); -} +CEL_INTERNAL_VALUE_IMPL(TimestampValue); std::string TimestampValue::DebugString() const { return internal::FormatTimestamp(value()).value(); } -void TimestampValue::CopyTo(Value& address) const { - CEL_INTERNAL_VALUE_COPY_TO(TimestampValue, *this, address); -} - -void TimestampValue::MoveTo(Value& address) { - CEL_INTERNAL_VALUE_MOVE_TO(TimestampValue, *this, address); -} - -bool TimestampValue::Equals(const Value& other) const { - return kind() == other.kind() && - value() == internal::down_cast(other).value(); -} - -void TimestampValue::HashValue(absl::HashState state) const { - absl::HashState::combine(std::move(state), type(), value()); -} - } // namespace cel diff --git a/base/values/timestamp_value.h b/base/values/timestamp_value.h index 9e22c1a7c..c22c83d9e 100644 --- a/base/values/timestamp_value.h +++ b/base/values/timestamp_value.h @@ -17,58 +17,45 @@ #include -#include "absl/hash/hash.h" #include "absl/time/time.h" -#include "base/kind.h" -#include "base/type.h" +#include "base/types/timestamp_type.h" #include "base/value.h" namespace cel { -class ValueFactory; +class TimestampValue final + : public base_internal::SimpleValue { + private: + using Base = base_internal::SimpleValue; -class TimestampValue final : public Value, - public base_internal::ResourceInlined { public: - static Persistent UnixEpoch( - ValueFactory& value_factory); + using Base::kKind; - Persistent type() const override; + using Base::Is; - Kind kind() const override { return Kind::kTimestamp; } - - std::string DebugString() const override; + static Persistent UnixEpoch( + ValueFactory& value_factory); - constexpr absl::Time value() const { return value_; } + using Base::kind; - private: - template - friend class base_internal::ValueHandle; - friend class base_internal::ValueHandleBase; + using Base::type; - // Called by base_internal::ValueHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Value& value) { - return value.kind() == Kind::kTimestamp; - } + std::string DebugString() const; - // Called by `base_internal::ValueHandle` to construct value inline. - explicit TimestampValue(absl::Time value) : value_(value) {} + using Base::HashValue; - TimestampValue() = delete; + using Base::Equals; - TimestampValue(const TimestampValue&) = default; - TimestampValue(TimestampValue&&) = default; + using Base::value; - // See comments for respective member functions on `Value`. - void CopyTo(Value& address) const override; - void MoveTo(Value& address) override; - bool Equals(const Value& other) const override; - void HashValue(absl::HashState state) const override; + private: + using Base::Base; - absl::Time value_; + CEL_INTERNAL_SIMPLE_VALUE_MEMBERS(TimestampValue); }; +CEL_INTERNAL_SIMPLE_VALUE_STANDALONES(TimestampValue); + } // namespace cel #endif // THIRD_PARTY_CEL_CPP_BASE_VALUES_TIMESTAMP_VALUE_H_ diff --git a/base/values/type_value.cc b/base/values/type_value.cc index 748738c7a..01e2ad9d2 100644 --- a/base/values/type_value.cc +++ b/base/values/type_value.cc @@ -17,35 +17,15 @@ #include #include -#include "base/types/type_type.h" -#include "internal/casts.h" - namespace cel { -namespace { - -using base_internal::PersistentHandleFactory; - -} - -Persistent TypeValue::type() const { - return PersistentHandleFactory::MakeUnmanaged( - TypeType::Get()); -} +CEL_INTERNAL_VALUE_IMPL(TypeValue); std::string TypeValue::DebugString() const { return value()->DebugString(); } -void TypeValue::CopyTo(Value& address) const { - CEL_INTERNAL_VALUE_COPY_TO(TypeValue, *this, address); -} - -void TypeValue::MoveTo(Value& address) { - CEL_INTERNAL_VALUE_MOVE_TO(TypeValue, *this, address); -} - bool TypeValue::Equals(const Value& other) const { return kind() == other.kind() && - value() == internal::down_cast(other).value(); + value() == static_cast(other).value(); } void TypeValue::HashValue(absl::HashState state) const { diff --git a/base/values/type_value.h b/base/values/type_value.h index 0635a7988..a1b3c3ed6 100644 --- a/base/values/type_value.h +++ b/base/values/type_value.h @@ -15,53 +15,58 @@ #ifndef THIRD_PARTY_CEL_CPP_BASE_VALUES_TYPE_VALUE_H_ #define THIRD_PARTY_CEL_CPP_BASE_VALUES_TYPE_VALUE_H_ +#include #include #include #include "absl/hash/hash.h" #include "base/kind.h" #include "base/type.h" +#include "base/types/type_type.h" #include "base/value.h" namespace cel { -// TypeValue represents an instance of cel::Type. -class TypeValue final : public Value, base_internal::ResourceInlined { +class TypeValue final : public Value, public base_internal::InlineData { public: - Persistent type() const override; + static constexpr Kind kKind = TypeType::kKind; - Kind kind() const override { return Kind::kType; } + static bool Is(const Value& value) { return value.kind() == kKind; } - std::string DebugString() const override; + constexpr Kind kind() const { return kKind; } - Persistent value() const { return value_; } + const Persistent& type() const { return TypeType::Get(); } - private: - template - friend class base_internal::ValueHandle; - friend class base_internal::ValueHandleBase; + std::string DebugString() const; + + void HashValue(absl::HashState state) const; - // Called by base_internal::ValueHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Value& value) { return value.kind() == Kind::kType; } + bool Equals(const Value& other) const; - // Called by `base_internal::ValueHandle` to construct value inline. - explicit TypeValue(Persistent type) : value_(std::move(type)) {} + constexpr const Persistent& value() const { return value_; } + + private: + friend class PersistentValueHandle; + template + friend class base_internal::AnyData; - TypeValue() = delete; + static constexpr uintptr_t kVirtualPointer = + base_internal::kStoredInline | + (static_cast(kKind) << base_internal::kKindShift); + + explicit TypeValue(Persistent value) : value_(std::move(value)) {} TypeValue(const TypeValue&) = default; TypeValue(TypeValue&&) = default; + TypeValue& operator=(const TypeValue&) = default; + TypeValue& operator=(TypeValue&&) = default; - // See comments for respective member functions on `Value`. - void CopyTo(Value& address) const override; - void MoveTo(Value& address) override; - bool Equals(const Value& other) const override; - void HashValue(absl::HashState state) const override; - + uintptr_t vptr_ ABSL_ATTRIBUTE_UNUSED = kVirtualPointer; Persistent value_; }; +CEL_INTERNAL_VALUE_DECL(TypeValue); + } // namespace cel #endif // THIRD_PARTY_CEL_CPP_BASE_VALUES_TYPE_VALUE_H_ diff --git a/base/values/uint_value.cc b/base/values/uint_value.cc index a8a0d143f..650aaa259 100644 --- a/base/values/uint_value.cc +++ b/base/values/uint_value.cc @@ -15,44 +15,15 @@ #include "base/values/uint_value.h" #include -#include #include "absl/strings/str_cat.h" -#include "base/types/uint_type.h" -#include "internal/casts.h" namespace cel { -namespace { - -using base_internal::PersistentHandleFactory; - -} - -Persistent UintValue::type() const { - return PersistentHandleFactory::MakeUnmanaged( - UintType::Get()); -} +CEL_INTERNAL_VALUE_IMPL(UintValue); std::string UintValue::DebugString() const { return absl::StrCat(value(), "u"); } -void UintValue::CopyTo(Value& address) const { - CEL_INTERNAL_VALUE_COPY_TO(UintValue, *this, address); -} - -void UintValue::MoveTo(Value& address) { - CEL_INTERNAL_VALUE_MOVE_TO(UintValue, *this, address); -} - -bool UintValue::Equals(const Value& other) const { - return kind() == other.kind() && - value() == internal::down_cast(other).value(); -} - -void UintValue::HashValue(absl::HashState state) const { - absl::HashState::combine(std::move(state), type(), value()); -} - } // namespace cel diff --git a/base/values/uint_value.h b/base/values/uint_value.h index d76c44d06..383665c64 100644 --- a/base/values/uint_value.h +++ b/base/values/uint_value.h @@ -18,49 +18,40 @@ #include #include -#include "absl/hash/hash.h" -#include "base/kind.h" -#include "base/type.h" +#include "base/types/uint_type.h" #include "base/value.h" namespace cel { -class UintValue final : public Value, public base_internal::ResourceInlined { - public: - Persistent type() const override; +class UintValue final : public base_internal::SimpleValue { + private: + using Base = base_internal::SimpleValue; - Kind kind() const override { return Kind::kUint; } + public: + using Base::kKind; - std::string DebugString() const override; + using Base::Is; - constexpr uint64_t value() const { return value_; } + using Base::kind; - private: - template - friend class base_internal::ValueHandle; - friend class base_internal::ValueHandleBase; + using Base::type; - // Called by base_internal::ValueHandleBase to implement Is for Transient and - // Persistent. - static bool Is(const Value& value) { return value.kind() == Kind::kUint; } + std::string DebugString() const; - // Called by `base_internal::ValueHandle` to construct value inline. - explicit UintValue(uint64_t value) : value_(value) {} + using Base::HashValue; - UintValue() = delete; + using Base::Equals; - UintValue(const UintValue&) = default; - UintValue(UintValue&&) = default; + using Base::value; - // See comments for respective member functions on `Value`. - void CopyTo(Value& address) const override; - void MoveTo(Value& address) override; - bool Equals(const Value& other) const override; - void HashValue(absl::HashState state) const override; + private: + using Base::Base; - uint64_t value_; + CEL_INTERNAL_SIMPLE_VALUE_MEMBERS(UintValue); }; +CEL_INTERNAL_SIMPLE_VALUE_STANDALONES(UintValue); + } // namespace cel #endif // THIRD_PARTY_CEL_CPP_BASE_VALUES_UINT_VALUE_H_ diff --git a/extensions/protobuf/memory_manager.cc b/extensions/protobuf/memory_manager.cc index 7e8d92eb8..77e168ca3 100644 --- a/extensions/protobuf/memory_manager.cc +++ b/extensions/protobuf/memory_manager.cc @@ -15,37 +15,15 @@ #include "extensions/protobuf/memory_manager.h" #include -#include #include "absl/base/macros.h" #include "absl/base/optimization.h" namespace cel::extensions { -MemoryManager::AllocationResult ProtoMemoryManager::Allocate( - size_t size, size_t align) { - void* pointer; - if (arena_ != nullptr) { - pointer = arena_->AllocateAligned(size, align); - } else { - if (ABSL_PREDICT_TRUE(align <= alignof(std::max_align_t))) { - pointer = ::operator new(size, std::nothrow); - } else { - pointer = ::operator new(size, static_cast(align), - std::nothrow); - } - } - return {pointer}; -} - -void ProtoMemoryManager::Deallocate(void* pointer, size_t size, size_t align) { - // Only possible when `arena_` is nullptr. - ABSL_HARDENING_ASSERT(arena_ == nullptr); - if (ABSL_PREDICT_TRUE(align <= alignof(std::max_align_t))) { - ::operator delete(pointer, size); - } else { - ::operator delete(pointer, size, static_cast(align)); - } +void* ProtoMemoryManager::Allocate(size_t size, size_t align) { + ABSL_HARDENING_ASSERT(arena_ != nullptr); + return arena_->AllocateAligned(size, align); } void ProtoMemoryManager::OwnDestructor(void* pointer, void (*destruct)(void*)) { diff --git a/extensions/protobuf/memory_manager.h b/extensions/protobuf/memory_manager.h index 4d515140c..e9c09c97c 100644 --- a/extensions/protobuf/memory_manager.h +++ b/extensions/protobuf/memory_manager.h @@ -58,9 +58,7 @@ class ProtoMemoryManager final : public ArenaMemoryManager { } private: - AllocationResult Allocate(size_t size, size_t align) override; - - void Deallocate(void* pointer, size_t size, size_t align) override; + void* Allocate(size_t size, size_t align) override; void OwnDestructor(void* pointer, void (*destruct)(void*)) override; From ee2261f8a6b0f27ac7a84c0b979ae47eb6a75cf2 Mon Sep 17 00:00:00 2001 From: jcking Date: Fri, 24 Jun 2022 14:59:07 +0000 Subject: [PATCH 008/303] Internal change PiperOrigin-RevId: 457012079 --- base/BUILD | 2 + base/internal/BUILD | 2 + base/internal/data.h | 234 ++++++++++++++++++++------------- base/internal/managed_memory.h | 7 +- base/memory_manager.h | 5 +- base/type.cc | 9 +- base/type.h | 11 +- base/value.cc | 17 +-- base/value.h | 15 +-- base/values/bytes_value.cc | 12 +- base/values/bytes_value.h | 11 +- base/values/enum_value.h | 7 +- base/values/error_value.h | 6 +- base/values/string_value.cc | 12 +- base/values/string_value.h | 10 +- base/values/type_value.h | 6 +- internal/BUILD | 27 ++++ internal/assume_aligned.h | 40 ++++++ internal/launder.h | 43 ++++++ internal/unreachable.h | 41 ++++++ 20 files changed, 359 insertions(+), 158 deletions(-) create mode 100644 internal/assume_aligned.h create mode 100644 internal/launder.h create mode 100644 internal/unreachable.h diff --git a/base/BUILD b/base/BUILD index 33f2f8f1d..9a4f65ba1 100644 --- a/base/BUILD +++ b/base/BUILD @@ -123,6 +123,7 @@ cc_library( "//base/internal:type", "//internal:casts", "//internal:rtti", + "//internal:unreachable", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/hash", @@ -227,6 +228,7 @@ cc_library( "//internal:rtti", "//internal:strings", "//internal:time", + "//internal:unreachable", "//internal:utf8", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", diff --git a/base/internal/BUILD b/base/internal/BUILD index 77759e1e9..77c25b4de 100644 --- a/base/internal/BUILD +++ b/base/internal/BUILD @@ -21,6 +21,8 @@ cc_library( hdrs = ["data.h"], deps = [ "//base:kind", + "//internal:assume_aligned", + "//internal:launder", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/numeric:bits", ], diff --git a/base/internal/data.h b/base/internal/data.h index 73f53e0b7..188e4bbf7 100644 --- a/base/internal/data.h +++ b/base/internal/data.h @@ -19,12 +19,16 @@ #include #include #include +#include +#include #include #include "absl/base/attributes.h" #include "absl/base/macros.h" #include "absl/numeric/bits.h" #include "base/kind.h" +#include "internal/assume_aligned.h" +#include "internal/launder.h" namespace cel::base_internal { @@ -98,16 +102,33 @@ class Data {}; // // For inline data, Kind is stored in the most significant byte of `metadata`. class InlineData /* : public Data */ { - // uintptr_t metadata - public: static void* operator new(size_t) = delete; static void* operator new[](size_t) = delete; static void operator delete(void*) = delete; static void operator delete[](void*) = delete; + + InlineData(const InlineData&) = default; + InlineData(InlineData&&) = default; + + InlineData& operator=(const InlineData&) = default; + InlineData& operator=(InlineData&&) = default; + + protected: + constexpr explicit InlineData(uintptr_t metadata) : metadata_(metadata) {} + + private: + uintptr_t metadata_ ABSL_ATTRIBUTE_UNUSED = 0; }; +static_assert(std::is_trivially_copyable_v, + "InlineData must be trivially copyable"); +static_assert(std::is_trivially_destructible_v, + "InlineData must be trivially destructible"); +static_assert(sizeof(InlineData) == sizeof(uintptr_t), + "InlineData has unexpected padding"); + // Used purely for a static_assert. constexpr size_t HeapDataMetadataAndReferenceCountOffset(); @@ -120,9 +141,6 @@ constexpr size_t HeapDataMetadataAndReferenceCountOffset(); // with twos complement integers, allows us to easily detect incorrect reference // counting as the reference count will be negative. class HeapData /* : public Data */ { - // uintptr_t vptr - // std::atomic metadata_and_reference_count - public: HeapData(const HeapData&) = delete; HeapData(HeapData&&) = delete; @@ -156,115 +174,118 @@ static_assert(sizeof(HeapData) == sizeof(uintptr_t) * 2, // Provides introspection for `Data`. class Metadata final { public: - ABSL_ATTRIBUTE_ALWAYS_INLINE static Metadata* For(Data* data) { - ABSL_ASSERT(data != nullptr); - return reinterpret_cast(data); - } - - ABSL_ATTRIBUTE_ALWAYS_INLINE static const Metadata* For(const Data* data) { - ABSL_ASSERT(data != nullptr); - return reinterpret_cast(data); - } - - Kind kind() const { - ABSL_ASSERT(!IsNull()); - return static_cast( - ((IsStoredInline() - ? *reinterpret_cast(this) - : reference_count()->load(std::memory_order_relaxed)) >> + static Kind Kind(const Data& data) { + ABSL_ASSERT(!IsNull(data)); + return static_cast( + ((IsStoredInline(data) + ? VirtualPointer(data) + : ReferenceCount(data).load(std::memory_order_relaxed)) >> kKindShift) & kKindMask); } - DataLocality locality() const { + static DataLocality Locality(const Data& data) { // We specifically do not use `IsArenaAllocated()` and // `IsReferenceCounted()` here due to performance reasons. This code is // called often in handle implementations. - return IsNull() ? DataLocality::kNull - : IsStoredInline() ? DataLocality::kStoredInline - : ((reference_count()->load(std::memory_order_relaxed) & + return IsNull(data) ? DataLocality::kNull + : IsStoredInline(data) ? DataLocality::kStoredInline + : ((ReferenceCount(data).load(std::memory_order_relaxed) & kArenaAllocated) != kArenaAllocated) ? DataLocality::kReferenceCounted : DataLocality::kArenaAllocated; } - ABSL_ATTRIBUTE_ALWAYS_INLINE bool IsNull() const { - return *reinterpret_cast(this) == 0; - } + static bool IsNull(const Data& data) { return VirtualPointer(data) == 0; } - ABSL_ATTRIBUTE_ALWAYS_INLINE bool IsStoredInline() const { - return (*reinterpret_cast(this) & kStoredInline) == - kStoredInline; + static bool IsStoredInline(const Data& data) { + return (VirtualPointer(data) & kStoredInline) == kStoredInline; } - bool IsArenaAllocated() const { - return !IsNull() && !IsStoredInline() && + static bool IsArenaAllocated(const Data& data) { + return !IsNull(data) && !IsStoredInline(data) && // We use relaxed because the top 8 bits are never mutated during // reference counting and that is all we care about. - (reference_count()->load(std::memory_order_relaxed) & + (ReferenceCount(data).load(std::memory_order_relaxed) & kArenaAllocated) == kArenaAllocated; } - bool IsReferenceCounted() const { - return !IsNull() && !IsStoredInline() && + static bool IsReferenceCounted(const Data& data) { + return !IsNull(data) && !IsStoredInline(data) && // We use relaxed because the top 8 bits are never mutated during // reference counting and that is all we care about. - (reference_count()->load(std::memory_order_relaxed) & + (ReferenceCount(data).load(std::memory_order_relaxed) & kArenaAllocated) != kArenaAllocated; } - void Ref() const { - ABSL_ASSERT(IsReferenceCounted()); + static void Ref(const Data& data) { + ABSL_ASSERT(IsReferenceCounted(data)); const auto count = - (reference_count()->fetch_add(1, std::memory_order_relaxed)) & + (ReferenceCount(data).fetch_add(1, std::memory_order_relaxed)) & kReferenceCountMask; ABSL_ASSERT(count > 0 && count < kReferenceCountMax); } - bool Unref() const { - ABSL_ASSERT(IsReferenceCounted()); + static bool Unref(const Data& data) { + ABSL_ASSERT(IsReferenceCounted(data)); const auto count = - (reference_count()->fetch_sub(1, std::memory_order_seq_cst)) & + (ReferenceCount(data).fetch_sub(1, std::memory_order_seq_cst)) & kReferenceCountMask; ABSL_ASSERT(count > 0 && count < kReferenceCountMax); return count == 1; } - bool IsUnique() const { - ABSL_ASSERT(IsReferenceCounted()); - return ((reference_count()->fetch_add(1, std::memory_order_acquire)) & + static bool IsUnique(const Data& data) { + ABSL_ASSERT(IsReferenceCounted(data)); + return ((ReferenceCount(data).fetch_add(1, std::memory_order_acquire)) & kReferenceCountMask) == 1; } - bool IsTriviallyCopyable() const { - ABSL_ASSERT(IsStoredInline()); - return (*reinterpret_cast(this) & kTriviallyCopyable) == - kTriviallyCopyable; + static bool IsTriviallyCopyable(const Data& data) { + ABSL_ASSERT(IsStoredInline(data)); + return (VirtualPointer(data) & kTriviallyCopyable) == kTriviallyCopyable; } - bool IsTriviallyDestructible() const { - ABSL_ASSERT(IsStoredInline()); - return (*reinterpret_cast(this) & - kTriviallyDestructible) == kTriviallyDestructible; + static bool IsTriviallyDestructible(const Data& data) { + ABSL_ASSERT(IsStoredInline(data)); + return (VirtualPointer(data) & kTriviallyDestructible) == + kTriviallyDestructible; } // Used by `MemoryManager::New()`. - void SetArenaAllocated() { - reference_count()->fetch_or(kArenaAllocated, std::memory_order_relaxed); + static void SetArenaAllocated(const Data& data) { + ReferenceCount(data).fetch_or(kArenaAllocated, std::memory_order_relaxed); } // Used by `MemoryManager::New()`. - void SetReferenceCounted() { - reference_count()->fetch_or(kReferenceCounted, std::memory_order_relaxed); + static void SetReferenceCounted(const Data& data) { + ReferenceCount(data).fetch_or(kReferenceCounted, std::memory_order_relaxed); } private: - std::atomic* reference_count() const { - return reinterpret_cast*>( - const_cast(reinterpret_cast(this) + 1)); + static uintptr_t VirtualPointer(const Data& data) { + // The vptr, or equivalent, is stored at offset 0. Inform the compiler that + // `data` is aligned to at least `uintptr_t`. + return *reinterpret_cast( + internal::assume_aligned(&data)); + } + + static std::atomic& ReferenceCount(const Data& data) { + // For arena allocated and reference counted, the reference count + // immediately follows the vptr, or equivalent, at offset 0. So its offset + // is `sizeof(uintptr_t)`. Inform the compiler that `data` is aligned to at + // least `uintptr_t` and `std::atomic`. + return *reinterpret_cast*>( + internal::assume_aligned)>( + const_cast(reinterpret_cast(&data) + + sizeof(uintptr_t)))); } Metadata() = delete; + Metadata(const Metadata&) = delete; + Metadata(Metadata&&) = delete; + Metadata& operator=(const Metadata&) = delete; + Metadata& operator=(Metadata&&) = delete; }; template @@ -272,7 +293,7 @@ union alignas(Align) AnyDataStorage final { AnyDataStorage() : pointer(0) {} uintptr_t pointer; - char buffer[Size]; + uint8_t buffer[Size]; }; // Struct capable of storing data directly or a pointer to data. This is used by @@ -287,80 +308,84 @@ struct AnyData { static_assert(Align >= alignof(uintptr_t), "Align must be at least alignof(uintptr_t)"); - using Storage = AnyDataStorage; + static constexpr size_t kSize = Size; + static constexpr size_t kAlign = Align; + + using Storage = AnyDataStorage; Kind kind() const { ABSL_ASSERT(!IsNull()); - return Metadata::For(get())->kind(); + return Metadata::Kind(*get()); } DataLocality locality() const { - return storage.pointer == 0 ? DataLocality::kNull - : (storage.pointer & kStoredInline) == kStoredInline + return pointer() == 0 ? DataLocality::kNull + : (pointer() & kStoredInline) == kStoredInline ? DataLocality::kStoredInline - : (storage.pointer & kPointerArenaAllocated) == - kPointerArenaAllocated + : (pointer() & kPointerArenaAllocated) == kPointerArenaAllocated ? DataLocality::kArenaAllocated : DataLocality::kReferenceCounted; } - bool IsNull() const { return storage.pointer == 0; } + bool IsNull() const { return pointer() == 0; } bool IsStoredInline() const { - return (storage.pointer & kStoredInline) == kStoredInline; + return (pointer() & kStoredInline) == kStoredInline; } bool IsArenaAllocated() const { - return (storage.pointer & kPointerArenaAllocated) == kPointerArenaAllocated; + return (pointer() & kPointerArenaAllocated) == kPointerArenaAllocated; } bool IsReferenceCounted() const { - return storage.pointer != 0 && - (storage.pointer & (kStoredInline | kPointerArenaAllocated)) == 0; + return pointer() != 0 && + (pointer() & (kStoredInline | kPointerArenaAllocated)) == 0; } void Ref() const { ABSL_ASSERT(IsReferenceCounted()); - Metadata::For(get())->Ref(); + Metadata::Ref(*get()); } bool Unref() const { ABSL_ASSERT(IsReferenceCounted()); - return Metadata::For(get())->Unref(); + return Metadata::Unref(*get()); } bool IsUnique() const { ABSL_ASSERT(IsReferenceCounted()); - return Metadata::For(get())->IsUnique(); + return Metadata::IsUnique(*get()); } bool IsTriviallyCopyable() const { ABSL_ASSERT(IsStoredInline()); - return Metadata::For(get())->IsTriviallyCopyable(); + return Metadata::IsTriviallyCopyable(*get()); } bool IsTriviallyDestructible() const { ABSL_ASSERT(IsStoredInline()); - return Metadata::For(get())->IsTriviallyDestructible(); + return Metadata::IsTriviallyDestructible(*get()); } // IMPORTANT: Do not use `Metadata::For(get())` unless you know what you are // doing, instead us the method of the same name in this class. - ABSL_ATTRIBUTE_ALWAYS_INLINE Data* get() const { - return (storage.pointer & kStoredInline) == kStoredInline - ? reinterpret_cast( - const_cast(&storage.pointer)) - : reinterpret_cast(storage.pointer & kPointerMask); + Data* get() const { + // We launder to ensure the compiler does not make any assumptions about the + // content of storage in regards to const. + return internal::launder( + (pointer() & kStoredInline) == kStoredInline + ? reinterpret_cast(const_cast(buffer())) + : reinterpret_cast(pointer() & kPointerMask)); } // Copy the bytes from other, similar to `std::memcpy`. void CopyFrom(const AnyData& other) { - std::memcpy(&storage.buffer[0], &other.storage.buffer[0], Size); + std::memcpy(buffer(), other.buffer(), kSize); } // Move the bytes from other, similar to `std::memcpy` and `std::memset`. void MoveFrom(AnyData& other) { - std::memcpy(&storage.buffer[0], &other.storage.buffer[0], Size); + CopyFrom(other); other.Clear(); } @@ -373,7 +398,7 @@ struct AnyData { void Clear() { // We only need to clear the first `sizeof(uintptr_t)` bytes as that is // consulted to determine locality. - storage.pointer = 0; + pointer() = 0; } // Counterpart to `Metadata::SetArenaAllocated()` and @@ -382,18 +407,45 @@ struct AnyData { ABSL_ASSERT(absl::countr_zero(reinterpret_cast(&data)) >= 2); // Assert pointer alignment results in at least the 2 least // significant bits being unset. - storage.pointer = - reinterpret_cast(&data) | - (Metadata::For(&data)->IsArenaAllocated() ? kPointerArenaAllocated : 0); + pointer() = reinterpret_cast(&data) | + (Metadata::IsArenaAllocated(data) ? kPointerArenaAllocated : 0); } template void ConstructInline(Args&&... args) { - ::new (&storage.buffer[0]) T(std::forward(args)...); - ABSL_ASSERT(absl::countr_zero(storage.pointer) == + ::new (buffer()) T(std::forward(args)...); + ABSL_ASSERT(absl::countr_zero(pointer()) == 0); // Assert the least significant bit is set. } + uint8_t* buffer() { + // We launder because `storage.pointer` is technically the active member by + // default and we want to ensure the compiler does not make any assumptions + // based on this. + return &internal::launder(&storage)->buffer[0]; + } + + const uint8_t* buffer() const { + // We launder because `storage.pointer` is technically the active member by + // default and we want to ensure the compiler does not make any assumptions + // based on this. + return &internal::launder(&storage)->buffer[0]; + } + + uintptr_t& pointer() { + // We launder because `storage.pointer` is technically the active member by + // default and we want to ensure the compiler does not make any assumptions + // based on this. + return internal::launder(&storage)->pointer; + } + + const uintptr_t& pointer() const { + // We launder because `storage.pointer` is technically the active member by + // default and we want to ensure the compiler does not make any assumptions + // based on this. + return internal::launder(&storage)->pointer; + } + Storage storage; }; diff --git a/base/internal/managed_memory.h b/base/internal/managed_memory.h index 366a2b014..dd7211b76 100644 --- a/base/internal/managed_memory.h +++ b/base/internal/managed_memory.h @@ -150,19 +150,18 @@ class ManagedMemory final { ABSL_ASSERT(absl::countr_zero(reinterpret_cast(pointer)) >= 2); pointer_ = reinterpret_cast(pointer) | - (Metadata::For(pointer)->IsArenaAllocated() ? kPointerArenaAllocated - : 0); + (Metadata::IsArenaAllocated(*pointer) ? kPointerArenaAllocated : 0); } void Ref() const { if (pointer_ != 0 && (pointer_ & kPointerArenaAllocated) == 0) { - Metadata::For(this)->Ref(); + Metadata::Ref(**this); } } void Unref() const { if (pointer_ != 0 && (pointer_ & kPointerArenaAllocated) == 0 && - Metadata::For(get())->Unref()) { + Metadata::Unref(**this)) { delete static_cast(get()); } } diff --git a/base/memory_manager.h b/base/memory_manager.h index aa7b56841..67a76e00e 100644 --- a/base/memory_manager.h +++ b/base/memory_manager.h @@ -48,7 +48,7 @@ class MemoryManager { "T must only be stored inline"); if (!allocation_only_) { T* pointer = new T(std::forward(args)...); - base_internal::Metadata::For(pointer)->SetReferenceCounted(); + base_internal::Metadata::SetReferenceCounted(*pointer); return ManagedMemory(pointer); } void* pointer = Allocate(sizeof(T), alignof(T)); @@ -57,8 +57,7 @@ class MemoryManager { OwnDestructor(pointer, &base_internal::MemoryManagerDestructor::Destruct); } - base_internal::Metadata::For(reinterpret_cast(pointer)) - ->SetArenaAllocated(); + base_internal::Metadata::SetArenaAllocated(*reinterpret_cast(pointer)); return ManagedMemory(reinterpret_cast(pointer)); } diff --git a/base/type.cc b/base/type.cc index cecdfa73d..164a30b84 100644 --- a/base/type.cc +++ b/base/type.cc @@ -38,6 +38,7 @@ #include "base/types/timestamp_type.h" #include "base/types/type_type.h" #include "base/types/uint_type.h" +#include "internal/unreachable.h" namespace cel { @@ -232,7 +233,7 @@ void PersistentTypeHandle::CopyFrom(const PersistentTypeHandle& other) { !other.data_.IsTriviallyCopyable())) { // Type currently has only trivially copyable inline // representations. - ABSL_INTERNAL_UNREACHABLE; + internal::unreachable(); } else { // We can simply just copy the bytes. data_.CopyFrom(other.data_); @@ -249,7 +250,7 @@ void PersistentTypeHandle::MoveFrom(PersistentTypeHandle& other) { !other.data_.IsTriviallyCopyable())) { // Type currently has only trivially copyable inline // representations. - ABSL_INTERNAL_UNREACHABLE; + internal::unreachable(); } else { // We can simply just copy the bytes. data_.MoveFrom(other.data_); @@ -276,7 +277,7 @@ void PersistentTypeHandle::Destruct() { if (ABSL_PREDICT_FALSE(!data_.IsTriviallyDestructible())) { // Type currently has only trivially destructible inline // representations. - ABSL_INTERNAL_UNREACHABLE; + internal::unreachable(); } break; case DataLocality::kReferenceCounted: @@ -302,7 +303,7 @@ void PersistentTypeHandle::Delete() const { delete static_cast(static_cast(data_.get())); break; default: - ABSL_INTERNAL_UNREACHABLE; + internal::unreachable(); } } diff --git a/base/type.h b/base/type.h index f3a9cb62c..df43b9278 100644 --- a/base/type.h +++ b/base/type.h @@ -52,7 +52,7 @@ class Type : public base_internal::Data { static bool Is(const Type& type ABSL_ATTRIBUTE_UNUSED) { return true; } // Returns the type kind. - Kind kind() const { return base_internal::Metadata::For(this)->kind(); } + Kind kind() const { return base_internal::Metadata::Kind(*this); } // Returns the type name, i.e. "list". absl::string_view name() const; @@ -133,7 +133,7 @@ class PersistentTypeHandle final { return *this; } - Type* get() const { return static_cast(data_.get()); } + Type* get() const { return reinterpret_cast(data_.get()); } explicit operator bool() const { return !data_.IsNull(); } @@ -270,7 +270,8 @@ class SimpleType : public Type, public InlineData { static bool Is(const Type& type) { return type.kind() == kKind; } - SimpleType() = default; + constexpr SimpleType() : InlineData(kMetadata) {} + SimpleType(const SimpleType&) = default; SimpleType(SimpleType&&) = default; SimpleType& operator=(const SimpleType&) = default; @@ -291,11 +292,9 @@ class SimpleType : public Type, public InlineData { private: friend class PersistentTypeHandle; - static constexpr uintptr_t kVirtualPointer = + static constexpr uintptr_t kMetadata = kStoredInline | kTriviallyCopyable | kTriviallyDestructible | (static_cast(kKind) << kKindShift); - - uintptr_t vptr_ ABSL_ATTRIBUTE_UNUSED = kVirtualPointer; }; } // namespace base_internal diff --git a/base/value.cc b/base/value.cc index a16a490bb..593a51c3d 100644 --- a/base/value.cc +++ b/base/value.cc @@ -34,6 +34,7 @@ #include "base/values/timestamp_value.h" #include "base/values/type_value.h" #include "base/values/uint_value.h" +#include "internal/unreachable.h" namespace cel { @@ -72,7 +73,7 @@ const Persistent& Value::type() const { case Kind::kStruct: return static_cast(this)->type().As(); default: - ABSL_INTERNAL_UNREACHABLE; + internal::unreachable(); } } @@ -109,7 +110,7 @@ std::string Value::DebugString() const { case Kind::kStruct: return static_cast(this)->DebugString(); default: - ABSL_INTERNAL_UNREACHABLE; + internal::unreachable(); } } @@ -148,7 +149,7 @@ void Value::HashValue(absl::HashState state) const { case Kind::kStruct: return static_cast(this)->HashValue(std::move(state)); default: - ABSL_INTERNAL_UNREACHABLE; + internal::unreachable(); } } @@ -188,7 +189,7 @@ bool Value::Equals(const Value& other) const { case Kind::kStruct: return static_cast(this)->Equals(other); default: - ABSL_INTERNAL_UNREACHABLE; + internal::unreachable(); } } @@ -240,7 +241,7 @@ void PersistentValueHandle::CopyFrom(const PersistentValueHandle& other) { *static_cast(other.data_.get())); break; default: - ABSL_INTERNAL_UNREACHABLE; + internal::unreachable(); } } else { // We can simply just copy the bytes. @@ -278,7 +279,7 @@ void PersistentValueHandle::MoveFrom(PersistentValueHandle& other) { std::move(*static_cast(other.data_.get()))); break; default: - ABSL_INTERNAL_UNREACHABLE; + internal::unreachable(); } other.Destruct(); other.data_.Clear(); @@ -323,7 +324,7 @@ void PersistentValueHandle::Destruct() { data_.Destruct(); break; default: - ABSL_INTERNAL_UNREACHABLE; + internal::unreachable(); } } break; @@ -353,7 +354,7 @@ void PersistentValueHandle::Delete() const { delete static_cast(static_cast(data_.get())); break; default: - ABSL_INTERNAL_UNREACHABLE; + internal::unreachable(); } } diff --git a/base/value.h b/base/value.h index e878e59aa..35154f514 100644 --- a/base/value.h +++ b/base/value.h @@ -56,7 +56,7 @@ class Value : public base_internal::Data { // Returns the kind of the value. This is equivalent to `type().kind()` but // faster in many scenarios. As such it should be preffered when only the kind // is required. - Kind kind() const { return base_internal::Metadata::For(this)->kind(); } + Kind kind() const { return base_internal::Metadata::Kind(*this); } // Returns the type of the value. If you only need the kind, prefer `kind()`. const Persistent& type() const; @@ -144,7 +144,7 @@ class PersistentValueHandle final { return *this; } - Value* get() const { return static_cast(data_.get()); } + Value* get() const { return reinterpret_cast(data_.get()); } explicit operator bool() const { return !data_.IsNull(); } @@ -212,7 +212,7 @@ class SimpleValue : public Value, InlineData { static bool Is(const Value& value) { return value.kind() == kKind; } - explicit SimpleValue(U value) : value_(value) {} + explicit SimpleValue(U value) : InlineData(kMetadata), value_(value) {} SimpleValue(const SimpleValue&) = default; SimpleValue(SimpleValue&&) = default; @@ -237,13 +237,12 @@ class SimpleValue : public Value, InlineData { private: friend class PersistentValueHandle; - static constexpr uintptr_t kVirtualPointer = + static constexpr uintptr_t kMetadata = kStoredInline | (std::is_trivially_copyable_v ? kTriviallyCopyable : 0) | (std::is_trivially_destructible_v ? kTriviallyDestructible : 0) | (static_cast(kKind) << kKindShift); - uintptr_t vptr_ ABSL_ATTRIBUTE_UNUSED = kVirtualPointer; U value_; }; @@ -254,7 +253,7 @@ class SimpleValue : public Value, InlineData { static bool Is(const Value& value) { return value.kind() == kKind; } - SimpleValue() {} + constexpr SimpleValue() : InlineData(kMetadata) {} SimpleValue(const SimpleValue&) = default; SimpleValue(SimpleValue&&) = default; @@ -274,11 +273,9 @@ class SimpleValue : public Value, InlineData { private: friend class PersistentValueHandle; - static constexpr uintptr_t kVirtualPointer = + static constexpr uintptr_t kMetadata = kStoredInline | kTriviallyCopyable | kTriviallyDestructible | (static_cast(kKind) << kKindShift); - - uintptr_t vptr_ ABSL_ATTRIBUTE_UNUSED = kVirtualPointer; }; } // namespace base_internal diff --git a/base/values/bytes_value.cc b/base/values/bytes_value.cc index f8f94cef6..99b29a0f5 100644 --- a/base/values/bytes_value.cc +++ b/base/values/bytes_value.cc @@ -207,11 +207,11 @@ std::string BytesValue::ToString() const { } absl::Cord BytesValue::ToCord() const { - switch (base_internal::Metadata::For(this)->locality()) { + switch (base_internal::Metadata::Locality(*this)) { case base_internal::DataLocality::kNull: return absl::Cord(); case base_internal::DataLocality::kStoredInline: - if (base_internal::Metadata::For(this)->IsTriviallyCopyable()) { + if (base_internal::Metadata::IsTriviallyCopyable(*this)) { return absl::MakeCordFromExternal( static_cast(this) ->value_, @@ -221,11 +221,11 @@ absl::Cord BytesValue::ToCord() const { ->value_; } case base_internal::DataLocality::kReferenceCounted: - base_internal::Metadata::For(this)->Ref(); + base_internal::Metadata::Ref(*this); return absl::MakeCordFromExternal( static_cast(this)->value_, [this]() { - if (base_internal::Metadata::For(this)->Unref()) { + if (base_internal::Metadata::Unref(*this)) { delete static_cast(this); } }); @@ -252,11 +252,11 @@ void BytesValue::HashValue(absl::HashState state) const { } base_internal::BytesValueRep BytesValue::rep() const { - switch (base_internal::Metadata::For(this)->locality()) { + switch (base_internal::Metadata::Locality(*this)) { case base_internal::DataLocality::kNull: return base_internal::BytesValueRep(); case base_internal::DataLocality::kStoredInline: - if (base_internal::Metadata::For(this)->IsTriviallyCopyable()) { + if (base_internal::Metadata::IsTriviallyCopyable(*this)) { return base_internal::BytesValueRep( absl::in_place_type, static_cast(this) diff --git a/base/values/bytes_value.h b/base/values/bytes_value.h index e8bc3dfaa..e82af88b4 100644 --- a/base/values/bytes_value.h +++ b/base/values/bytes_value.h @@ -110,18 +110,18 @@ class InlinedCordBytesValue final : public BytesValue, template friend class AnyData; - static constexpr uintptr_t kVirtualPointer = + static constexpr uintptr_t kMetadata = base_internal::kStoredInline | (static_cast(kKind) << base_internal::kKindShift); - explicit InlinedCordBytesValue(absl::Cord value) : value_(std::move(value)) {} + explicit InlinedCordBytesValue(absl::Cord value) + : base_internal::InlineData(kMetadata), value_(std::move(value)) {} InlinedCordBytesValue(const InlinedCordBytesValue&) = default; InlinedCordBytesValue(InlinedCordBytesValue&&) = default; InlinedCordBytesValue& operator=(const InlinedCordBytesValue&) = default; InlinedCordBytesValue& operator=(InlinedCordBytesValue&&) = default; - uintptr_t vptr_ ABSL_ATTRIBUTE_UNUSED = kVirtualPointer; absl::Cord value_; }; @@ -136,13 +136,13 @@ class InlinedStringViewBytesValue final : public BytesValue, template friend class AnyData; - static constexpr uintptr_t kVirtualPointer = + static constexpr uintptr_t kMetadata = base_internal::kStoredInline | base_internal::kTriviallyCopyable | base_internal::kTriviallyDestructible | (static_cast(kKind) << base_internal::kKindShift); explicit InlinedStringViewBytesValue(absl::string_view value) - : value_(value) {} + : base_internal::InlineData(kMetadata), value_(value) {} InlinedStringViewBytesValue(const InlinedStringViewBytesValue&) = default; InlinedStringViewBytesValue(InlinedStringViewBytesValue&&) = default; @@ -151,7 +151,6 @@ class InlinedStringViewBytesValue final : public BytesValue, InlinedStringViewBytesValue& operator=(InlinedStringViewBytesValue&&) = default; - uintptr_t vptr_ ABSL_ATTRIBUTE_UNUSED = kVirtualPointer; absl::string_view value_; }; diff --git a/base/values/enum_value.h b/base/values/enum_value.h index 4004dcd43..ad6bb2c3c 100644 --- a/base/values/enum_value.h +++ b/base/values/enum_value.h @@ -64,14 +64,15 @@ class EnumValue final : public Value, public base_internal::InlineData { template friend class base_internal::AnyData; - static constexpr uintptr_t kVirtualPointer = + static constexpr uintptr_t kMetadata = base_internal::kStoredInline | (static_cast(kKind) << base_internal::kKindShift); EnumValue(Persistent type, int64_t number) - : type_(std::move(type)), number_(number) {} + : base_internal::InlineData(kMetadata), + type_(std::move(type)), + number_(number) {} - uintptr_t vptr_ ABSL_ATTRIBUTE_UNUSED = kVirtualPointer; Persistent type_; int64_t number_; }; diff --git a/base/values/error_value.h b/base/values/error_value.h index 4807ddb40..34a93307b 100644 --- a/base/values/error_value.h +++ b/base/values/error_value.h @@ -51,18 +51,18 @@ class ErrorValue final : public Value, public base_internal::InlineData { template friend class base_internal::AnyData; - static constexpr uintptr_t kVirtualPointer = + static constexpr uintptr_t kMetadata = base_internal::kStoredInline | (static_cast(kKind) << base_internal::kKindShift); - explicit ErrorValue(absl::Status value) : value_(std::move(value)) {} + explicit ErrorValue(absl::Status value) + : base_internal::InlineData(kMetadata), value_(std::move(value)) {} ErrorValue(const ErrorValue&) = default; ErrorValue(ErrorValue&&) = default; ErrorValue& operator=(const ErrorValue&) = default; ErrorValue& operator=(ErrorValue&&) = default; - uintptr_t vptr_ ABSL_ATTRIBUTE_UNUSED = kVirtualPointer; absl::Status value_; }; diff --git a/base/values/string_value.cc b/base/values/string_value.cc index e4795492a..c14185ea6 100644 --- a/base/values/string_value.cc +++ b/base/values/string_value.cc @@ -209,11 +209,11 @@ std::string StringValue::ToString() const { } absl::Cord StringValue::ToCord() const { - switch (base_internal::Metadata::For(this)->locality()) { + switch (base_internal::Metadata::Locality(*this)) { case base_internal::DataLocality::kNull: return absl::Cord(); case base_internal::DataLocality::kStoredInline: - if (base_internal::Metadata::For(this)->IsTriviallyCopyable()) { + if (base_internal::Metadata::IsTriviallyCopyable(*this)) { return absl::MakeCordFromExternal( static_cast( this) @@ -224,11 +224,11 @@ absl::Cord StringValue::ToCord() const { ->value_; } case base_internal::DataLocality::kReferenceCounted: - base_internal::Metadata::For(this)->Ref(); + base_internal::Metadata::Ref(*this); return absl::MakeCordFromExternal( static_cast(this)->value_, [this]() { - if (base_internal::Metadata::For(this)->Unref()) { + if (base_internal::Metadata::Unref(*this)) { delete static_cast(this); } }); @@ -255,11 +255,11 @@ void StringValue::HashValue(absl::HashState state) const { } base_internal::StringValueRep StringValue::rep() const { - switch (base_internal::Metadata::For(this)->locality()) { + switch (base_internal::Metadata::Locality(*this)) { case base_internal::DataLocality::kNull: return base_internal::StringValueRep(); case base_internal::DataLocality::kStoredInline: - if (base_internal::Metadata::For(this)->IsTriviallyCopyable()) { + if (base_internal::Metadata::IsTriviallyCopyable(*this)) { return base_internal::StringValueRep( absl::in_place_type, static_cast( diff --git a/base/values/string_value.h b/base/values/string_value.h index abb85e2aa..5f1062827 100644 --- a/base/values/string_value.h +++ b/base/values/string_value.h @@ -111,19 +111,18 @@ class InlinedCordStringValue final : public StringValue, template friend class AnyData; - static constexpr uintptr_t kVirtualPointer = + static constexpr uintptr_t kMetadata = base_internal::kStoredInline | (static_cast(kKind) << base_internal::kKindShift); explicit InlinedCordStringValue(absl::Cord value) - : value_(std::move(value)) {} + : base_internal::InlineData(kMetadata), value_(std::move(value)) {} InlinedCordStringValue(const InlinedCordStringValue&) = default; InlinedCordStringValue(InlinedCordStringValue&&) = default; InlinedCordStringValue& operator=(const InlinedCordStringValue&) = default; InlinedCordStringValue& operator=(InlinedCordStringValue&&) = default; - uintptr_t vptr_ ABSL_ATTRIBUTE_UNUSED = kVirtualPointer; absl::Cord value_; }; @@ -139,13 +138,13 @@ class InlinedStringViewStringValue final : public StringValue, template friend class AnyData; - static constexpr uintptr_t kVirtualPointer = + static constexpr uintptr_t kMetadata = base_internal::kStoredInline | base_internal::kTriviallyCopyable | base_internal::kTriviallyDestructible | (static_cast(kKind) << base_internal::kKindShift); explicit InlinedStringViewStringValue(absl::string_view value) - : value_(value) {} + : base_internal::InlineData(kMetadata), value_(value) {} InlinedStringViewStringValue(const InlinedStringViewStringValue&) = default; InlinedStringViewStringValue(InlinedStringViewStringValue&&) = default; @@ -154,7 +153,6 @@ class InlinedStringViewStringValue final : public StringValue, InlinedStringViewStringValue& operator=(InlinedStringViewStringValue&&) = default; - uintptr_t vptr_ ABSL_ATTRIBUTE_UNUSED = kVirtualPointer; absl::string_view value_; }; diff --git a/base/values/type_value.h b/base/values/type_value.h index a1b3c3ed6..58abaac87 100644 --- a/base/values/type_value.h +++ b/base/values/type_value.h @@ -50,18 +50,18 @@ class TypeValue final : public Value, public base_internal::InlineData { template friend class base_internal::AnyData; - static constexpr uintptr_t kVirtualPointer = + static constexpr uintptr_t kMetadata = base_internal::kStoredInline | (static_cast(kKind) << base_internal::kKindShift); - explicit TypeValue(Persistent value) : value_(std::move(value)) {} + explicit TypeValue(Persistent value) + : base_internal::InlineData(kMetadata), value_(std::move(value)) {} TypeValue(const TypeValue&) = default; TypeValue(TypeValue&&) = default; TypeValue& operator=(const TypeValue&) = default; TypeValue& operator=(TypeValue&&) = default; - uintptr_t vptr_ ABSL_ATTRIBUTE_UNUSED = kVirtualPointer; Persistent value_; }; diff --git a/internal/BUILD b/internal/BUILD index 2faad5b31..03c0b2b55 100644 --- a/internal/BUILD +++ b/internal/BUILD @@ -16,6 +16,33 @@ package(default_visibility = ["//visibility:public"]) licenses(["notice"]) +cc_library( + name = "assume_aligned", + hdrs = ["assume_aligned.h"], + deps = [ + "@com_google_absl//absl/base:config", + "@com_google_absl//absl/base:core_headers", + ], +) + +cc_library( + name = "unreachable", + hdrs = ["unreachable.h"], + deps = [ + "@com_google_absl//absl/base:config", + "@com_google_absl//absl/base:core_headers", + ], +) + +cc_library( + name = "launder", + hdrs = ["launder.h"], + deps = [ + "@com_google_absl//absl/base:config", + "@com_google_absl//absl/base:core_headers", + ], +) + cc_library( name = "benchmark", testonly = True, diff --git a/internal/assume_aligned.h b/internal/assume_aligned.h new file mode 100644 index 000000000..1dcfcf0dd --- /dev/null +++ b/internal/assume_aligned.h @@ -0,0 +1,40 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_ASSUME_ALIGNED_H_ +#define THIRD_PARTY_CEL_CPP_INTERNAL_ASSUME_ALIGNED_H_ + +#include // std::assume_aligned in C++20 + +#include "absl/base/attributes.h" +#include "absl/base/config.h" + +namespace cel::internal { + +// C++14 version of C++20's std::assume_aligned(). +template +ABSL_MUST_USE_RESULT inline T* assume_aligned(T* pointer) noexcept { +#if defined(__cpp_lib_assume_aligned) && __cpp_lib_assume_aligned >= 201811L + return std::assume_aligned(pointer); +#elif (defined(__GNUC__) && !defined(__clang__)) || \ + ABSL_HAVE_BUILTIN(__builtin_assume_aligned) + return static_cast(__builtin_assume_aligned(pointer, N)); +#else + return pointer; +#endif +} + +} // namespace cel::internal + +#endif // THIRD_PARTY_CEL_CPP_INTERNAL_ASSUME_ALIGNED_H_ diff --git a/internal/launder.h b/internal/launder.h new file mode 100644 index 000000000..2f3807dfc --- /dev/null +++ b/internal/launder.h @@ -0,0 +1,43 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_LAUNDER_H_ +#define THIRD_PARTY_CEL_CPP_INTERNAL_LAUNDER_H_ + +#if __cplusplus >= 201703L +#include +#endif + +#include "absl/base/attributes.h" +#include "absl/base/config.h" + +namespace cel::internal { + +// C++14 version of C++17's std::launder(). +template +ABSL_MUST_USE_RESULT inline T* launder(T* pointer) noexcept { +#if __cplusplus >= 201703L + return std::launder(pointer); +#elif ABSL_HAVE_BUILTIN(__builtin_launder) || \ + (defined(__GNUC__) && __GNUC__ >= 7) + return __builtin_launder(pointer); +#else + // Fallback to undefined behavior. + return pointer; +#endif +} + +} // namespace cel::internal + +#endif // THIRD_PARTY_CEL_CPP_INTERNAL_LAUNDER_H_ diff --git a/internal/unreachable.h b/internal/unreachable.h new file mode 100644 index 000000000..5b72c3582 --- /dev/null +++ b/internal/unreachable.h @@ -0,0 +1,41 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_UNREACHABLE_H_ +#define THIRD_PARTY_CEL_CPP_INTERNAL_UNREACHABLE_H_ + +#include +#include // std::unreachable in C++20 + +#include "absl/base/attributes.h" +#include "absl/base/config.h" + +namespace cel::internal { + +// C++14 version of C++20's std::unreachable(). +ABSL_ATTRIBUTE_NORETURN inline void unreachable() noexcept { +#if defined(__cpp_lib_unreachable) && __cpp_lib_unreachable >= 202202L + std::unreachable(); +#elif defined(__GNUC__) || ABSL_HAVE_BUILTIN(__builtin_unreachable) + __builtin_unreachable(); +#elif defined(_MSC_VER) + __assume(false); +#else + std::abort(); +#endif +} + +} // namespace cel::internal + +#endif // THIRD_PARTY_CEL_CPP_INTERNAL_UNREACHABLE_H_ From 367dc5a6fbbbade7f3c0e694e22f6663e91a5ab0 Mon Sep 17 00:00:00 2001 From: jcking Date: Wed, 29 Jun 2022 22:43:24 +0000 Subject: [PATCH 009/303] Update `google::api::expr::runtime::CelMap::ListKeys` to return `absl::StatusOr` PiperOrigin-RevId: 458081210 --- eval/eval/comprehension_step.cc | 8 ++++++-- eval/public/BUILD | 2 -- eval/public/builtin_func_test.cc | 8 ++++++-- eval/public/cel_attribute_test.cc | 4 +++- eval/public/cel_value.cc | 2 +- eval/public/cel_value.h | 2 +- eval/public/cel_value_test.cc | 4 +++- eval/public/comparison_functions.cc | 6 +++++- eval/public/containers/container_backed_map_impl.h | 4 +++- eval/public/containers/field_backed_map_impl_test.cc | 2 +- eval/public/containers/internal_field_backed_map_impl.cc | 4 +++- eval/public/containers/internal_field_backed_map_impl.h | 2 +- .../containers/internal_field_backed_map_impl_test.cc | 2 +- eval/public/set_util.cc | 4 ++-- eval/public/structs/cel_proto_wrap_util.cc | 6 ++++-- eval/public/structs/cel_proto_wrap_util_test.cc | 2 +- eval/public/structs/cel_proto_wrapper_test.cc | 2 +- eval/public/structs/proto_message_type_adapter.cc | 2 +- eval/public/transform_utility.cc | 6 +++--- eval/public/value_export_util.cc | 2 +- eval/tests/benchmark_test.cc | 4 +++- tools/flatbuffers_backed_impl.cc | 2 +- tools/flatbuffers_backed_impl.h | 2 +- tools/flatbuffers_backed_impl_test.cc | 8 ++++---- 24 files changed, 56 insertions(+), 34 deletions(-) diff --git a/eval/eval/comprehension_step.cc b/eval/eval/comprehension_step.cc index 64b98f058..6a3d1ec3b 100644 --- a/eval/eval/comprehension_step.cc +++ b/eval/eval/comprehension_step.cc @@ -2,6 +2,7 @@ #include #include +#include #include "absl/status/status.h" #include "absl/strings/str_cat.h" @@ -232,8 +233,11 @@ absl::Status ListKeysStep::ProjectKeys(ExecutionFrame* frame) const { } const CelValue& map = frame->value_stack().Peek(); - frame->value_stack().PopAndPush( - CelValue::CreateList(map.MapOrDie()->ListKeys())); + auto list_keys = map.MapOrDie()->ListKeys(); + if (!list_keys.ok()) { + return std::move(list_keys).status(); + } + frame->value_stack().PopAndPush(CelValue::CreateList(*list_keys)); return absl::OkStatus(); } diff --git a/eval/public/BUILD b/eval/public/BUILD index a123004cd..3243b3e60 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -543,13 +543,11 @@ cc_test( ":unknown_attribute_set", ":unknown_set", "//base:memory_manager", - "//eval/public/structs:legacy_type_adapter", "//eval/public/structs:legacy_type_info_apis", "//eval/public/structs:trivial_legacy_type_info", "//eval/public/testing:matchers", "//eval/testutil:test_message_cc_proto", "//extensions/protobuf:memory_manager", - "//internal:no_destructor", "//internal:status_macros", "//internal:testing", "@com_google_absl//absl/status", diff --git a/eval/public/builtin_func_test.cc b/eval/public/builtin_func_test.cc index d0065e788..75c09cb3a 100644 --- a/eval/public/builtin_func_test.cc +++ b/eval/public/builtin_func_test.cc @@ -1088,7 +1088,9 @@ class FakeErrorMap : public CelMap { return absl::nullopt; } - const CelList* ListKeys() const override { return nullptr; } + absl::StatusOr ListKeys() const override { + return absl::UnimplementedError("CelMap::ListKeys is not implemented"); + } }; template @@ -1120,7 +1122,9 @@ class FakeMap : public CelMap { return it->second; } - const CelList* ListKeys() const override { return keys_.get(); } + absl::StatusOr ListKeys() const override { + return keys_.get(); + } private: std::map data_; diff --git a/eval/public/cel_attribute_test.cc b/eval/public/cel_attribute_test.cc index 7bd09c640..d89d1074b 100644 --- a/eval/public/cel_attribute_test.cc +++ b/eval/public/cel_attribute_test.cc @@ -27,7 +27,9 @@ class DummyMap : public CelMap { absl::optional operator[](CelValue value) const override { return CelValue::CreateNull(); } - const CelList* ListKeys() const override { return nullptr; } + absl::StatusOr ListKeys() const override { + return absl::UnimplementedError("CelMap::ListKeys is not implemented"); + } int size() const override { return 0; } }; diff --git a/eval/public/cel_value.cc b/eval/public/cel_value.cc index 4dc5bcc77..746a6b498 100644 --- a/eval/public/cel_value.cc +++ b/eval/public/cel_value.cc @@ -97,7 +97,7 @@ struct DebugStringVisitor { } std::string operator()(const CelMap* arg) { - const CelList* keys = arg->ListKeys(); + const CelList* keys = arg->ListKeys().value(); std::vector elements; elements.reserve(keys->size()); for (int i = 0; i < keys->size(); i++) { diff --git a/eval/public/cel_value.h b/eval/public/cel_value.h index fe5a6f1dd..727d9af44 100644 --- a/eval/public/cel_value.h +++ b/eval/public/cel_value.h @@ -578,7 +578,7 @@ class CelMap { // Return list of keys. CelList is owned by Arena, so no // ownership is passed. - virtual const CelList* ListKeys() const = 0; + virtual absl::StatusOr ListKeys() const = 0; virtual ~CelMap() {} }; diff --git a/eval/public/cel_value_test.cc b/eval/public/cel_value_test.cc index 683518563..13a1e2108 100644 --- a/eval/public/cel_value_test.cc +++ b/eval/public/cel_value_test.cc @@ -28,7 +28,9 @@ class DummyMap : public CelMap { absl::optional operator[](CelValue value) const override { return CelValue::CreateNull(); } - const CelList* ListKeys() const override { return nullptr; } + absl::StatusOr ListKeys() const override { + return absl::UnimplementedError("CelMap::ListKeys is not implemented"); + } int size() const override { return 0; } }; diff --git a/eval/public/comparison_functions.cc b/eval/public/comparison_functions.cc index 649d66a5c..c05509733 100644 --- a/eval/public/comparison_functions.cc +++ b/eval/public/comparison_functions.cc @@ -243,7 +243,11 @@ absl::optional MapEqual(const CelMap* t1, const CelMap* t2) { return false; } - const CelList* keys = t1->ListKeys(); + auto list_keys = t1->ListKeys(); + if (!list_keys.ok()) { + return absl::nullopt; + } + const CelList* keys = *list_keys; for (int i = 0; i < keys->size(); i++) { CelValue key = (*keys)[i]; CelValue v1 = (*t1)[key].value(); diff --git a/eval/public/containers/container_backed_map_impl.h b/eval/public/containers/container_backed_map_impl.h index ea1976715..2a0e96933 100644 --- a/eval/public/containers/container_backed_map_impl.h +++ b/eval/public/containers/container_backed_map_impl.h @@ -30,7 +30,9 @@ class CelMapBuilder : public CelMap { return values_map_.contains(cel_key); } - const CelList* ListKeys() const override { return &key_list_; } + absl::StatusOr ListKeys() const override { + return &key_list_; + } private: // Custom CelList implementation for maintaining key list. diff --git a/eval/public/containers/field_backed_map_impl_test.cc b/eval/public/containers/field_backed_map_impl_test.cc index 1cf711851..69d446017 100644 --- a/eval/public/containers/field_backed_map_impl_test.cc +++ b/eval/public/containers/field_backed_map_impl_test.cc @@ -223,7 +223,7 @@ TEST(FieldBackedMapImplTest, KeyListTest) { google::protobuf::Arena arena; auto cel_map = CreateMap(&message, "string_int32_map", &arena); - const CelList* key_list = cel_map->ListKeys(); + const CelList* key_list = cel_map->ListKeys().value(); EXPECT_EQ(key_list->size(), 100); for (int i = 0; i < key_list->size(); i++) { diff --git a/eval/public/containers/internal_field_backed_map_impl.cc b/eval/public/containers/internal_field_backed_map_impl.cc index 4eabb99ad..d37caed93 100644 --- a/eval/public/containers/internal_field_backed_map_impl.cc +++ b/eval/public/containers/internal_field_backed_map_impl.cc @@ -156,7 +156,9 @@ int FieldBackedMapImpl::size() const { return reflection_->FieldSize(*message_, descriptor_); } -const CelList* FieldBackedMapImpl::ListKeys() const { return key_list_.get(); } +absl::StatusOr FieldBackedMapImpl::ListKeys() const { + return key_list_.get(); +} absl::StatusOr FieldBackedMapImpl::Has(const CelValue& key) const { #ifdef GOOGLE_PROTOBUF_HAS_CEL_MAP_REFLECTION_FRIEND diff --git a/eval/public/containers/internal_field_backed_map_impl.h b/eval/public/containers/internal_field_backed_map_impl.h index ae43a5e4c..ec773d9d2 100644 --- a/eval/public/containers/internal_field_backed_map_impl.h +++ b/eval/public/containers/internal_field_backed_map_impl.h @@ -43,7 +43,7 @@ class FieldBackedMapImpl : public CelMap { // Presence test function. absl::StatusOr Has(const CelValue& key) const override; - const CelList* ListKeys() const override; + absl::StatusOr ListKeys() const override; protected: // These methods are exposed as protected methods for testing purposes since diff --git a/eval/public/containers/internal_field_backed_map_impl_test.cc b/eval/public/containers/internal_field_backed_map_impl_test.cc index 392b84f35..60b77ab3d 100644 --- a/eval/public/containers/internal_field_backed_map_impl_test.cc +++ b/eval/public/containers/internal_field_backed_map_impl_test.cc @@ -274,7 +274,7 @@ TEST(FieldBackedMapImplTest, KeyListTest) { google::protobuf::Arena arena; auto cel_map = CreateMap(&message, "string_int32_map", &arena); - const CelList* key_list = cel_map->ListKeys(); + const CelList* key_list = cel_map->ListKeys().value(); EXPECT_EQ(key_list->size(), 100); for (int i = 0; i < key_list->size(); i++) { diff --git a/eval/public/set_util.cc b/eval/public/set_util.cc index 43c9e37a3..39d0c1298 100644 --- a/eval/public/set_util.cc +++ b/eval/public/set_util.cc @@ -68,8 +68,8 @@ int ComparisonImpl(const CelMap* lhs, const CelMap* rhs) { lhs_keys.reserve(lhs->size()); rhs_keys.reserve(lhs->size()); - const CelList* lhs_key_view = lhs->ListKeys(); - const CelList* rhs_key_view = rhs->ListKeys(); + const CelList* lhs_key_view = lhs->ListKeys().value(); + const CelList* rhs_key_view = rhs->ListKeys().value(); for (int i = 0; i < lhs->size(); i++) { lhs_keys.push_back(lhs_key_view->operator[](i)); diff --git a/eval/public/structs/cel_proto_wrap_util.cc b/eval/public/structs/cel_proto_wrap_util.cc index 8d03bee87..ffbb0aadf 100644 --- a/eval/public/structs/cel_proto_wrap_util.cc +++ b/eval/public/structs/cel_proto_wrap_util.cc @@ -133,7 +133,9 @@ class DynamicMap : public CelMap { int size() const override { return values_->fields_size(); } - const CelList* ListKeys() const override { return &key_list_; } + absl::StatusOr ListKeys() const override { + return &key_list_; + } private: // List of keys in Struct.fields map. @@ -549,7 +551,7 @@ google::protobuf::Message* MessageFromValue(const CelValue& value, Struct* json_ return nullptr; } const CelMap& map = *value.MapOrDie(); - const auto& keys = *map.ListKeys(); + const auto& keys = *map.ListKeys().value(); auto fields = json_struct->mutable_fields(); for (int i = 0; i < keys.size(); i++) { auto k = keys[i]; diff --git a/eval/public/structs/cel_proto_wrap_util_test.cc b/eval/public/structs/cel_proto_wrap_util_test.cc index 8611ef254..c78b186d0 100644 --- a/eval/public/structs/cel_proto_wrap_util_test.cc +++ b/eval/public/structs/cel_proto_wrap_util_test.cc @@ -292,7 +292,7 @@ TEST_F(CelProtoWrapperTest, UnwrapMessageToValueStruct) { ASSERT_OK(missing_field_presence); EXPECT_FALSE(*missing_field_presence); - const CelList* key_list = cel_map->ListKeys(); + const CelList* key_list = cel_map->ListKeys().value(); ASSERT_EQ(key_list->size(), kFields.size()); std::vector result_keys; diff --git a/eval/public/structs/cel_proto_wrapper_test.cc b/eval/public/structs/cel_proto_wrapper_test.cc index b9a7fefde..adbc98b64 100644 --- a/eval/public/structs/cel_proto_wrapper_test.cc +++ b/eval/public/structs/cel_proto_wrapper_test.cc @@ -282,7 +282,7 @@ TEST_F(CelProtoWrapperTest, UnwrapValueStruct) { ASSERT_OK(missing_field_presence); EXPECT_FALSE(*missing_field_presence); - const CelList* key_list = cel_map->ListKeys(); + const CelList* key_list = cel_map->ListKeys().value(); ASSERT_EQ(key_list->size(), kFields.size()); std::vector result_keys; diff --git a/eval/public/structs/proto_message_type_adapter.cc b/eval/public/structs/proto_message_type_adapter.cc index ccd2cded8..ebf97155b 100644 --- a/eval/public/structs/proto_message_type_adapter.cc +++ b/eval/public/structs/proto_message_type_adapter.cc @@ -312,7 +312,7 @@ absl::Status ProtoMessageTypeAdapter::SetField( ValidateSetFieldOp(value_field_descriptor != nullptr, field_name, "failed to find value field descriptor")); - const CelList* key_list = cel_map->ListKeys(); + CEL_ASSIGN_OR_RETURN(const CelList* key_list, cel_map->ListKeys()); for (int i = 0; i < key_list->size(); i++) { CelValue key = (*key_list)[i]; diff --git a/eval/public/transform_utility.cc b/eval/public/transform_utility.cc index 1a5cd5d6e..2206ff36b 100644 --- a/eval/public/transform_utility.cc +++ b/eval/public/transform_utility.cc @@ -85,9 +85,9 @@ absl::Status CelValueToValue(const CelValue& value, Value* result) { case CelValue::Type::kMap: { auto* map_value = result->mutable_map_value(); auto& cel_map = *value.MapOrDie(); - const auto& keys = *cel_map.ListKeys(); - for (int i = 0; i < keys.size(); ++i) { - CelValue key = keys[i]; + CEL_ASSIGN_OR_RETURN(const auto* keys, cel_map.ListKeys()); + for (int i = 0; i < keys->size(); ++i) { + CelValue key = (*keys)[i]; auto* entry = map_value->add_entries(); CEL_RETURN_IF_ERROR(CelValueToValue(key, entry->mutable_key())); auto optional_value = cel_map[key]; diff --git a/eval/public/value_export_util.cc b/eval/public/value_export_util.cc index 481c3301c..30cb067f8 100644 --- a/eval/public/value_export_util.cc +++ b/eval/public/value_export_util.cc @@ -121,7 +121,7 @@ absl::Status ExportAsProtoValue(const CelValue& in_value, Value* out_value) { } case CelValue::Type::kMap: { const CelMap* cel_map = in_value.MapOrDie(); - auto keys_list = cel_map->ListKeys(); + CEL_ASSIGN_OR_RETURN(auto keys_list, cel_map->ListKeys()); auto out_values = out_value->mutable_struct_value()->mutable_fields(); for (int i = 0; i < keys_list->size(); i++) { std::string key; diff --git a/eval/tests/benchmark_test.cc b/eval/tests/benchmark_test.cc index 220bcb1d7..a59157413 100644 --- a/eval/tests/benchmark_test.cc +++ b/eval/tests/benchmark_test.cc @@ -294,7 +294,9 @@ class RequestMap : public CelMap { return {}; } int size() const override { return 3; } - const CelList* ListKeys() const override { return nullptr; } + absl::StatusOr ListKeys() const override { + return absl::UnimplementedError("CelMap::ListKeys is not implemented"); + } }; // Uses a lazily constructed map container for "ip", "path", and "token". diff --git a/tools/flatbuffers_backed_impl.cc b/tools/flatbuffers_backed_impl.cc index 8462333a5..10c0b1cb8 100644 --- a/tools/flatbuffers_backed_impl.cc +++ b/tools/flatbuffers_backed_impl.cc @@ -130,7 +130,7 @@ class ObjectStringIndexedMapImpl : public CelMap { return absl::nullopt; } - const CelList* ListKeys() const override { return &keys_; } + absl::StatusOr ListKeys() const override { return &keys_; } private: struct KeyList : public CelList { diff --git a/tools/flatbuffers_backed_impl.h b/tools/flatbuffers_backed_impl.h index 86374a0be..e9ea9f29c 100644 --- a/tools/flatbuffers_backed_impl.h +++ b/tools/flatbuffers_backed_impl.h @@ -24,7 +24,7 @@ class FlatBuffersMapImpl : public CelMap { absl::optional operator[](CelValue cel_key) const override; - const CelList* ListKeys() const override { return &keys_; } + absl::StatusOr ListKeys() const override { return &keys_; } private: struct FieldList : public CelList { diff --git a/tools/flatbuffers_backed_impl_test.cc b/tools/flatbuffers_backed_impl_test.cc index 9f55f793a..55589bfd5 100644 --- a/tools/flatbuffers_backed_impl_test.cc +++ b/tools/flatbuffers_backed_impl_test.cc @@ -71,7 +71,7 @@ class FlatBuffersTest : public testing::Test { parser_.builder_.GetBufferPointer(), *schema_, &arena_); EXPECT_NE(nullptr, value); EXPECT_EQ(kNumFields, value->size()); - const CelList* keys = value->ListKeys(); + const CelList* keys = value->ListKeys().value(); EXPECT_NE(nullptr, keys); EXPECT_EQ(kNumFields, keys->size()); EXPECT_TRUE((*keys)[2].IsString()); @@ -496,7 +496,7 @@ TEST_F(FlatBuffersTest, VectorFieldDefaults) { EXPECT_TRUE(f->IsMap()); const CelMap& m = *f->MapOrDie(); EXPECT_EQ(0, m.size()); - EXPECT_EQ(0, m.ListKeys()->size()); + EXPECT_EQ(0, (*m.ListKeys())->size()); } { @@ -533,7 +533,7 @@ TEST_F(FlatBuffersTest, IndexedObjectVectorField) { EXPECT_TRUE(f->IsMap()); const CelMap& m = *f->MapOrDie(); EXPECT_EQ(4, m.size()); - const CelList& l = *m.ListKeys(); + const CelList& l = *m.ListKeys().value(); EXPECT_EQ(4, l.size()); EXPECT_TRUE(l[0].IsString()); EXPECT_TRUE(l[1].IsString()); @@ -591,7 +591,7 @@ TEST_F(FlatBuffersTest, IndexedObjectVectorFieldDefaults) { const CelMap& m = *f->MapOrDie(); EXPECT_EQ(1, m.size()); - const CelList& l = *m.ListKeys(); + const CelList& l = *m.ListKeys().value(); EXPECT_EQ(1, l.size()); EXPECT_TRUE(l[0].IsString()); EXPECT_EQ("", l[0].StringOrDie().value()); From 0493847719ec31f71a5884a2ab3886197f20dc91 Mon Sep 17 00:00:00 2001 From: jcking Date: Fri, 1 Jul 2022 17:51:29 +0000 Subject: [PATCH 010/303] Internal change PiperOrigin-RevId: 458502411 --- base/value_test.cc | 5 +++++ base/values/map_value.h | 4 ++++ 2 files changed, 9 insertions(+) diff --git a/base/value_test.cc b/base/value_test.cc index 23fa9c575..89846faea 100644 --- a/base/value_test.cc +++ b/base/value_test.cc @@ -413,6 +413,11 @@ class TestMapValue final : public MapValue { return absl::StrCat("{", absl::StrJoin(parts, ", "), "}"); } + absl::StatusOr> ListKeys( + ValueFactory& value_factory) const override { + return absl::UnimplementedError("MapValue::ListKeys is not implemented"); + } + const std::map& value() const { return entries_; } private: diff --git a/base/values/map_value.h b/base/values/map_value.h index d7f5f32ac..846297ae4 100644 --- a/base/values/map_value.h +++ b/base/values/map_value.h @@ -33,6 +33,7 @@ namespace cel { +class ListValue; class ValueFactory; // MapValue represents an instance of cel::MapType. @@ -63,6 +64,9 @@ class MapValue : public Value, public base_internal::HeapData { virtual absl::StatusOr Has( const Persistent& key) const = 0; + virtual absl::StatusOr> ListKeys( + ValueFactory& value_factory) const = 0; + protected: explicit MapValue(Persistent type); From df1bbdfbcb2f372303a6938e1b9c540c00480cb6 Mon Sep 17 00:00:00 2001 From: kuat Date: Thu, 14 Jul 2022 16:53:44 +0000 Subject: [PATCH 011/303] gcc fix. PiperOrigin-RevId: 460985644 --- base/internal/data.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/base/internal/data.h b/base/internal/data.h index 188e4bbf7..f2d8aa905 100644 --- a/base/internal/data.h +++ b/base/internal/data.h @@ -174,7 +174,7 @@ static_assert(sizeof(HeapData) == sizeof(uintptr_t) * 2, // Provides introspection for `Data`. class Metadata final { public: - static Kind Kind(const Data& data) { + static ::cel::Kind Kind(const Data& data) { ABSL_ASSERT(!IsNull(data)); return static_cast( ((IsStoredInline(data) From fc206f91a79d1570d3adb24aa0f36f241fdd1881 Mon Sep 17 00:00:00 2001 From: jcking Date: Fri, 15 Jul 2022 21:13:18 +0000 Subject: [PATCH 012/303] Remove `google::api::expr::Expr` from `CelAttribute` PiperOrigin-RevId: 461236369 --- ...builder_short_circuiting_conformance_test.cc | 14 ++++++-------- eval/compiler/flat_expr_builder_test.cc | 9 +++------ eval/eval/attribute_trail.cc | 4 ++-- eval/eval/comprehension_step_test.cc | 2 +- eval/eval/create_struct_step_test.cc | 2 +- eval/eval/evaluator_core_test.cc | 4 ++-- eval/eval/function_step_test.cc | 2 +- eval/eval/select_step_test.cc | 2 +- eval/eval/ternary_step_test.cc | 5 +++-- eval/public/builtin_func_test.cc | 2 +- eval/public/cel_attribute.cc | 10 +++------- eval/public/cel_attribute.h | 17 ++++++++++++----- eval/public/structs/cel_proto_wrap_util.cc | 6 +++--- eval/tests/unknowns_end_to_end_test.cc | 2 +- 14 files changed, 40 insertions(+), 41 deletions(-) diff --git a/eval/compiler/flat_expr_builder_short_circuiting_conformance_test.cc b/eval/compiler/flat_expr_builder_short_circuiting_conformance_test.cc index 83afcc396..708f1cb2a 100644 --- a/eval/compiler/flat_expr_builder_short_circuiting_conformance_test.cc +++ b/eval/compiler/flat_expr_builder_short_circuiting_conformance_test.cc @@ -252,8 +252,7 @@ TEST_P(ShortCircuitingTest, UnknownAnd) { const UnknownAttributeSet& attrs = result.UnknownSetOrDie()->unknown_attributes(); ASSERT_THAT(attrs.attributes(), testing::SizeIs(1)); - EXPECT_THAT(attrs.attributes()[0]->variable().ident_expr().name(), - testing::Eq("var1")); + EXPECT_THAT(attrs.attributes()[0]->variable_name(), testing::Eq("var1")); } TEST_P(ShortCircuitingTest, UnknownOr) { @@ -285,8 +284,7 @@ TEST_P(ShortCircuitingTest, UnknownOr) { const UnknownAttributeSet& attrs = result.UnknownSetOrDie()->unknown_attributes(); ASSERT_THAT(attrs.attributes(), testing::SizeIs(1)); - EXPECT_THAT(attrs.attributes()[0]->variable().ident_expr().name(), - testing::Eq("var1")); + EXPECT_THAT(attrs.attributes()[0]->variable_name(), testing::Eq("var1")); } TEST_P(ShortCircuitingTest, BasicTernary) { @@ -369,7 +367,7 @@ TEST_P(ShortCircuitingTest, TernaryUnknownCondHandling) { const auto& attrs = result.UnknownSetOrDie()->unknown_attributes().attributes(); ASSERT_THAT(attrs, SizeIs(1)); - EXPECT_THAT(attrs[0]->variable().ident_expr().name(), Eq("cond")); + EXPECT_THAT(attrs[0]->variable_name(), Eq("cond")); // Unknown branches are discarded if condition is unknown activation.set_unknown_attribute_patterns({CelAttributePattern("cond", {}), @@ -382,7 +380,7 @@ TEST_P(ShortCircuitingTest, TernaryUnknownCondHandling) { const auto& attrs2 = result.UnknownSetOrDie()->unknown_attributes().attributes(); ASSERT_THAT(attrs2, SizeIs(1)); - EXPECT_THAT(attrs2[0]->variable().ident_expr().name(), Eq("cond")); + EXPECT_THAT(attrs2[0]->variable_name(), Eq("cond")); } TEST_P(ShortCircuitingTest, TernaryUnknownArgsHandling) { @@ -418,7 +416,7 @@ TEST_P(ShortCircuitingTest, TernaryUnknownArgsHandling) { const auto& attrs3 = result.UnknownSetOrDie()->unknown_attributes().attributes(); ASSERT_THAT(attrs3, SizeIs(1)); - EXPECT_EQ(attrs3[0]->variable().ident_expr().name(), "arg2"); + EXPECT_EQ(attrs3[0]->variable_name(), "arg2"); } TEST_P(ShortCircuitingTest, TernaryUnknownAndErrorHandling) { @@ -456,7 +454,7 @@ TEST_P(ShortCircuitingTest, TernaryUnknownAndErrorHandling) { const auto& attrs = result.UnknownSetOrDie()->unknown_attributes().attributes(); ASSERT_THAT(attrs, SizeIs(1)); - EXPECT_EQ(attrs[0]->variable().ident_expr().name(), "cond"); + EXPECT_EQ(attrs[0]->variable_name(), "cond"); } const char* TestName(testing::TestParamInfo info) { diff --git a/eval/compiler/flat_expr_builder_test.cc b/eval/compiler/flat_expr_builder_test.cc index 8797d8608..404005195 100644 --- a/eval/compiler/flat_expr_builder_test.cc +++ b/eval/compiler/flat_expr_builder_test.cc @@ -1706,12 +1706,9 @@ TEST(FlatExprBuilderTest, Ternary) { ASSERT_TRUE(result.IsUnknownSet()); const UnknownSet* result_set = result.UnknownSetOrDie(); EXPECT_THAT(result_set->unknown_attributes().attributes().size(), Eq(1)); - EXPECT_THAT(result_set->unknown_attributes() - .attributes()[0] - ->variable() - .ident_expr() - .name(), - Eq("selector")); + EXPECT_THAT( + result_set->unknown_attributes().attributes()[0]->variable_name(), + Eq("selector")); } } diff --git a/eval/eval/attribute_trail.cc b/eval/eval/attribute_trail.cc index f623b7fea..6cbe0069a 100644 --- a/eval/eval/attribute_trail.cc +++ b/eval/eval/attribute_trail.cc @@ -24,8 +24,8 @@ AttributeTrail AttributeTrail::Step(CelAttributeQualifier qualifier, std::vector qualifiers = attribute_->qualifier_path(); qualifiers.push_back(qualifier); - auto attribute = - manager.New(attribute_->variable(), std::move(qualifiers)); + auto attribute = manager.New( + std::string(attribute_->variable_name()), std::move(qualifiers)); return AttributeTrail(attribute.release()); } diff --git a/eval/eval/comprehension_step_test.cc b/eval/eval/comprehension_step_test.cc index 5ee42109b..38d3b2340 100644 --- a/eval/eval/comprehension_step_test.cc +++ b/eval/eval/comprehension_step_test.cc @@ -158,7 +158,7 @@ TEST_F(ListKeysStepTest, MapPartiallyUnknown) { eval_result->UnknownSetOrDie()->unknown_attributes().attributes(); EXPECT_THAT(attrs, SizeIs(1)); - EXPECT_THAT(attrs.at(0)->variable().ident_expr().name(), Eq("var")); + EXPECT_THAT(attrs.at(0)->variable_name(), Eq("var")); EXPECT_THAT(attrs.at(0)->qualifier_path(), SizeIs(0)); } diff --git a/eval/eval/create_struct_step_test.cc b/eval/eval/create_struct_step_test.cc index a3190acfd..53bbb4547 100644 --- a/eval/eval/create_struct_step_test.cc +++ b/eval/eval/create_struct_step_test.cc @@ -64,7 +64,7 @@ absl::StatusOr RunExpression(absl::string_view field, create_struct->set_message_name("google.api.expr.runtime.TestMessage"); auto entry = create_struct->add_entries(); - entry->set_field_key(std::string(field)); + entry->set_field_key(field.data(), field.size()); auto adapter = type_registry.FindTypeAdapter(create_struct->message_name()); if (!adapter.has_value() || adapter->mutation_apis() == nullptr) { diff --git a/eval/eval/evaluator_core_test.cc b/eval/eval/evaluator_core_test.cc index 129ef5785..5c8dd7450 100644 --- a/eval/eval/evaluator_core_test.cc +++ b/eval/eval/evaluator_core_test.cc @@ -132,8 +132,8 @@ TEST(EvaluatorCoreTest, ExecutionFrameSetGetClearVar) { int64_t result_value; ASSERT_TRUE(result.GetValue(&result_value)); EXPECT_EQ(test_value, result_value); - ASSERT_TRUE(trail->attribute()->variable().has_ident_expr()); - ASSERT_EQ(trail->attribute()->variable().ident_expr().name(), "var"); + ASSERT_TRUE(trail->attribute()->has_variable_name()); + ASSERT_EQ(trail->attribute()->variable_name(), "var"); // Test that it goes away properly ASSERT_OK(frame.ClearIterVar()); diff --git a/eval/eval/function_step_test.cc b/eval/eval/function_step_test.cc index 3fcec6dc1..39b0db4f0 100644 --- a/eval/eval/function_step_test.cc +++ b/eval/eval/function_step_test.cc @@ -56,7 +56,7 @@ class ConstFunction : public CelFunction { static Expr::Call MakeCall(absl::string_view name) { Expr::Call call; - call.set_function(std::string(name)); + call.set_function(name.data(), name.size()); call.clear_target(); return call; } diff --git a/eval/eval/select_step_test.cc b/eval/eval/select_step_test.cc index 8cf94d609..40acab62e 100644 --- a/eval/eval/select_step_test.cc +++ b/eval/eval/select_step_test.cc @@ -77,7 +77,7 @@ absl::StatusOr RunExpression(const CelValue target, Expr dummy_expr; auto select = dummy_expr.mutable_select_expr(); - select->set_field(std::string(field)); + select->set_field(field.data(), field.size()); select->set_test_only(test); Expr* expr0 = select->mutable_operand(); diff --git a/eval/eval/ternary_step_test.cc b/eval/eval/ternary_step_test.cc index b89512d7c..e0ccb44e1 100644 --- a/eval/eval/ternary_step_test.cc +++ b/eval/eval/ternary_step_test.cc @@ -145,7 +145,8 @@ TEST_F(LogicStepTest, TestUnknownHandling) { auto ident_expr1 = expr1.mutable_ident_expr(); ident_expr1->set_name("name1"); - CelAttribute attr0(expr0, {}), attr1(expr1, {}); + CelAttribute attr0(expr0.ident_expr().name(), {}), + attr1(expr1.ident_expr().name(), {}); UnknownAttributeSet unknown_attr_set0({&attr0}); UnknownAttributeSet unknown_attr_set1({&attr1}); UnknownSet unknown_set0(unknown_attr_set0); @@ -161,7 +162,7 @@ TEST_F(LogicStepTest, TestUnknownHandling) { const auto& attrs = result.UnknownSetOrDie()->unknown_attributes().attributes(); ASSERT_THAT(attrs, testing::SizeIs(1)); - EXPECT_THAT(attrs[0]->variable().ident_expr().name(), Eq("name0")); + EXPECT_THAT(attrs[0]->variable_name(), Eq("name0")); } INSTANTIATE_TEST_SUITE_P(LogicStepTest, LogicStepTest, testing::Bool()); diff --git a/eval/public/builtin_func_test.cc b/eval/public/builtin_func_test.cc index 75c09cb3a..8b1a7d378 100644 --- a/eval/public/builtin_func_test.cc +++ b/eval/public/builtin_func_test.cc @@ -68,7 +68,7 @@ class BuiltinsTest : public ::testing::Test { Expr expr; SourceInfo source_info; auto call = expr.mutable_call_expr(); - call->set_function(std::string(operation)); + call->set_function(operation.data(), operation.size()); if (target.has_value()) { std::string param_name = "target"; diff --git a/eval/public/cel_attribute.cc b/eval/public/cel_attribute.cc index c7c26c95a..2498183f2 100644 --- a/eval/public/cel_attribute.cc +++ b/eval/public/cel_attribute.cc @@ -169,11 +169,7 @@ CelAttributePattern CreateCelAttributePattern( bool CelAttribute::operator==(const CelAttribute& other) const { // TODO(issues/41) we only support Ident-rooted attributes at the moment. - if (!variable().has_ident_expr() || !other.variable().has_ident_expr()) { - return false; - } - - if (variable().ident_expr().name() != other.variable().ident_expr().name()) { + if (variable_name() != other.variable_name()) { return false; } @@ -191,12 +187,12 @@ bool CelAttribute::operator==(const CelAttribute& other) const { } const absl::StatusOr CelAttribute::AsString() const { - if (variable_.ident_expr().name().empty()) { + if (variable_name().empty()) { return absl::InvalidArgumentError( "Only ident rooted attributes are supported."); } - std::string result = variable_.ident_expr().name(); + std::string result = std::string(variable_name()); for (const auto& qualifier : qualifier_path_) { CEL_RETURN_IF_ERROR( diff --git a/eval/public/cel_attribute.h b/eval/public/cel_attribute.h index afe8fab87..fa9a29a37 100644 --- a/eval/public/cel_attribute.h +++ b/eval/public/cel_attribute.h @@ -10,6 +10,7 @@ #include #include #include +#include #include "google/api/expr/v1alpha1/syntax.pb.h" #include "absl/status/status.h" @@ -151,12 +152,18 @@ class CelAttributeQualifierPattern { // CelAttribute represents resolved attribute path. class CelAttribute { public: - CelAttribute(google::api::expr::v1alpha1::Expr variable, + CelAttribute(std::string variable_name, std::vector qualifier_path) - : variable_(std::move(variable)), + : variable_name_(std::move(variable_name)), qualifier_path_(std::move(qualifier_path)) {} - const google::api::expr::v1alpha1::Expr& variable() const { return variable_; } + CelAttribute(const google::api::expr::v1alpha1::Expr& variable, + std::vector qualifier_path) + : CelAttribute(variable.ident_expr().name(), std::move(qualifier_path)) {} + + absl::string_view variable_name() const { return variable_name_; } + + bool has_variable_name() const { return !variable_name_.empty(); } const std::vector& qualifier_path() const { return qualifier_path_; @@ -167,7 +174,7 @@ class CelAttribute { const absl::StatusOr AsString() const; private: - google::api::expr::v1alpha1::Expr variable_; + std::string variable_name_; std::vector qualifier_path_; }; @@ -200,7 +207,7 @@ class CelAttributePattern { // Distinguishes between no-match, partial match and full match cases. MatchType IsMatch(const CelAttribute& attribute) const { MatchType result = MatchType::NONE; - if (attribute.variable().ident_expr().name() != variable_) { + if (attribute.variable_name() != variable_) { return result; } diff --git a/eval/public/structs/cel_proto_wrap_util.cc b/eval/public/structs/cel_proto_wrap_util.cc index ffbb0aadf..02752042c 100644 --- a/eval/public/structs/cel_proto_wrap_util.cc +++ b/eval/public/structs/cel_proto_wrap_util.cc @@ -434,7 +434,7 @@ google::protobuf::Message* MessageFromValue(const CelValue& value, BytesValue* w if (!value.GetValue(&view_val)) { return nullptr; } - wrapper->set_value(std::string(view_val.value())); + wrapper->set_value(view_val.value().data(), view_val.value().size()); return wrapper; } @@ -492,7 +492,7 @@ google::protobuf::Message* MessageFromValue(const CelValue& value, StringValue* if (!value.GetValue(&view_val)) { return nullptr; } - wrapper->set_value(std::string(view_val.value())); + wrapper->set_value(view_val.value().data(), view_val.value().size()); return wrapper; } @@ -629,7 +629,7 @@ google::protobuf::Message* MessageFromValue(const CelValue& value, Value* json) case CelValue::Type::kString: { CelValue::StringHolder val; if (value.GetValue(&val)) { - json->set_string_value(std::string(val.value())); + json->set_string_value(val.value().data(), val.value().size()); return json; } } break; diff --git a/eval/tests/unknowns_end_to_end_test.cc b/eval/tests/unknowns_end_to_end_test.cc index cd873ea51..5e4d969c2 100644 --- a/eval/tests/unknowns_end_to_end_test.cc +++ b/eval/tests/unknowns_end_to_end_test.cc @@ -168,7 +168,7 @@ MATCHER_P(FunctionCallIs, fn_name, "") { MATCHER_P(AttributeIs, attr, "") { const CelAttribute* result = arg; - return result->variable().ident_expr().name() == attr; + return result->variable_name() == attr; } TEST_F(UnknownsTest, NoUnknowns) { From 2a7856e71548aed89a409add073e7c8cb524b128 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Mon, 18 Jul 2022 22:27:04 +0000 Subject: [PATCH 013/303] Add native typed API for AstTraverse PiperOrigin-RevId: 461729503 --- base/ast.h | 74 ++++ base/ast_test.cc | 40 ++ eval/public/BUILD | 70 ++++ eval/public/ast_traverse_native.cc | 345 ++++++++++++++++ eval/public/ast_traverse_native.h | 66 ++++ eval/public/ast_traverse_native_test.cc | 438 +++++++++++++++++++++ eval/public/ast_visitor_native.h | 130 ++++++ eval/public/ast_visitor_native_base.h | 94 +++++ eval/public/source_position_native.cc | 66 ++++ eval/public/source_position_native.h | 62 +++ eval/public/source_position_native_test.cc | 108 +++++ 11 files changed, 1493 insertions(+) create mode 100644 eval/public/ast_traverse_native.cc create mode 100644 eval/public/ast_traverse_native.h create mode 100644 eval/public/ast_traverse_native_test.cc create mode 100644 eval/public/ast_visitor_native.h create mode 100644 eval/public/ast_visitor_native_base.h create mode 100644 eval/public/source_position_native.cc create mode 100644 eval/public/source_position_native.h create mode 100644 eval/public/source_position_native_test.cc diff --git a/base/ast.h b/base/ast.h index da735861d..4c8322f31 100644 --- a/base/ast.h +++ b/base/ast.h @@ -371,6 +371,15 @@ class CreateStruct { const Expr& map_key() const; + Expr& mutable_map_key() { + auto* value = absl::get_if>(&key_kind_); + if (value != nullptr) { + if (*value != nullptr) return **value; + } + key_kind_.emplace>(std::make_unique()); + return *absl::get>(key_kind_); + } + bool has_value() const { return value_ != nullptr; } const Expr& value() const; @@ -411,6 +420,8 @@ class CreateStruct { std::vector& mutable_entries() { return entries_; } + const std::string& message_name() const { return message_name_; } + bool operator==(const CreateStruct& other) const { return message_name_ == other.message_name_ && entries_ == other.entries_; } @@ -672,6 +683,15 @@ class Expr { return *default_constant; } + Constant& mutable_const_expr() { + auto* value = absl::get_if(&expr_kind_); + if (value != nullptr) { + return *value; + } + expr_kind_.emplace(); + return absl::get(expr_kind_); + } + const Ident& ident_expr() const { auto* value = absl::get_if(&expr_kind_); if (value != nullptr) { @@ -681,6 +701,15 @@ class Expr { return *default_ident; } + Ident& mutable_ident_expr() { + auto* value = absl::get_if(&expr_kind_); + if (value != nullptr) { + return *value; + } + expr_kind_.emplace(); + return absl::get(expr_kind_); + } + const Select& select_expr() const { auto* value = absl::get_if(&expr_kind_); + if (value != nullptr) { + return *value; + } + expr_kind_.emplace(expr_kind_); + } + const Call& call_expr() const { auto* value = absl::get_if(&expr_kind_); if (value != nullptr) { @@ -699,6 +737,15 @@ class Expr { return *default_call; } + Call& mutable_call_expr() { + auto* value = absl::get_if(&expr_kind_); + if (value != nullptr) { + return *value; + } + expr_kind_.emplace(); + return absl::get(expr_kind_); + } + const CreateList& list_expr() const { auto* value = absl::get_if(&expr_kind_); if (value != nullptr) { @@ -708,6 +755,15 @@ class Expr { return *default_create_list; } + CreateList& mutable_list_expr() { + auto* value = absl::get_if(&expr_kind_); + if (value != nullptr) { + return *value; + } + expr_kind_.emplace(); + return absl::get(expr_kind_); + } + const CreateStruct& struct_expr() const { auto* value = absl::get_if(&expr_kind_); if (value != nullptr) { @@ -717,6 +773,15 @@ class Expr { return *default_create_struct; } + CreateStruct& mutable_struct_expr() { + auto* value = absl::get_if(&expr_kind_); + if (value != nullptr) { + return *value; + } + expr_kind_.emplace(); + return absl::get(expr_kind_); + } + const Comprehension& comprehension_expr() const { auto* value = absl::get_if(&expr_kind_); if (value != nullptr) { @@ -726,6 +791,15 @@ class Expr { return *default_comprehension; } + Comprehension& mutable_comprehension_expr() { + auto* value = absl::get_if(&expr_kind_); + if (value != nullptr) { + return *value; + } + expr_kind_.emplace(); + return absl::get(expr_kind_); + } + bool operator==(const Expr& other) const { return id_ == other.id_ && expr_kind_ == other.expr_kind_; } diff --git a/base/ast_test.cc b/base/ast_test.cc index 222a1d216..a2d188188 100644 --- a/base/ast_test.cc +++ b/base/ast_test.cc @@ -172,6 +172,16 @@ TEST(AstTest, CreateStructEntryMutableValue) { ASSERT_EQ(absl::get(entry.value().expr_kind()).name(), "var"); } +TEST(AstTest, CreateStructEntryMutableMapKey) { + CreateStruct::Entry entry; + entry.mutable_map_key().set_expr_kind(Ident("key")); + ASSERT_TRUE(absl::holds_alternative(entry.map_key().expr_kind())); + ASSERT_EQ(absl::get(entry.map_key().expr_kind()).name(), "key"); + entry.mutable_map_key().set_expr_kind(Ident("new_key")); + ASSERT_TRUE(absl::holds_alternative(entry.map_key().expr_kind())); + ASSERT_EQ(absl::get(entry.map_key().expr_kind()).name(), "new_key"); +} + TEST(AstTest, ExprConstructionComprehension) { Comprehension comprehension; comprehension.set_iter_var("iter_var"); @@ -385,6 +395,36 @@ TEST(AstTest, TypeComparatorTest) { EXPECT_FALSE(type.type() == Type()); } +TEST(AstTest, ExprMutableConstruction) { + Expr expr; + expr.mutable_const_expr().set_constant_kind(true); + ASSERT_TRUE(expr.has_const_expr()); + EXPECT_TRUE(expr.const_expr().bool_value()); + expr.mutable_ident_expr().set_name("expr"); + ASSERT_TRUE(expr.has_ident_expr()); + EXPECT_FALSE(expr.has_const_expr()); + EXPECT_EQ(expr.ident_expr().name(), "expr"); + expr.mutable_select_expr().set_field("field"); + ASSERT_TRUE(expr.has_select_expr()); + EXPECT_FALSE(expr.has_ident_expr()); + EXPECT_EQ(expr.select_expr().field(), "field"); + expr.mutable_call_expr().set_function("function"); + ASSERT_TRUE(expr.has_call_expr()); + EXPECT_FALSE(expr.has_select_expr()); + EXPECT_EQ(expr.call_expr().function(), "function"); + expr.mutable_list_expr(); + EXPECT_TRUE(expr.has_list_expr()); + EXPECT_FALSE(expr.has_call_expr()); + expr.mutable_struct_expr().set_message_name("name"); + ASSERT_TRUE(expr.has_struct_expr()); + EXPECT_EQ(expr.struct_expr().message_name(), "name"); + EXPECT_FALSE(expr.has_list_expr()); + expr.mutable_comprehension_expr().set_accu_var("accu_var"); + ASSERT_TRUE(expr.has_comprehension_expr()); + EXPECT_FALSE(expr.has_list_expr()); + EXPECT_EQ(expr.comprehension_expr().accu_var(), "accu_var"); +} + } // namespace } // namespace internal } // namespace ast diff --git a/eval/public/BUILD b/eval/public/BUILD index 3243b3e60..73d99dee4 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -431,6 +431,15 @@ cc_library( ], ) +cc_library( + name = "source_position_native", + srcs = ["source_position_native.cc"], + hdrs = ["source_position_native.h"], + deps = [ + "//base:ast", + ], +) + cc_library( name = "ast_visitor", hdrs = [ @@ -453,6 +462,27 @@ cc_library( ], ) +cc_library( + name = "ast_visitor_native", + hdrs = [ + "ast_visitor_native.h", + ], + deps = [ + ":source_position_native", + "//base:ast", + ], +) + +cc_library( + name = "ast_visitor_native_base", + hdrs = [ + "ast_visitor_native_base.h", + ], + deps = [ + ":ast_visitor_native", + ], +) + cc_library( name = "ast_traverse", srcs = [ @@ -469,6 +499,22 @@ cc_library( ], ) +cc_library( + name = "ast_traverse_native", + srcs = [ + "ast_traverse_native.cc", + ], + hdrs = [ + "ast_traverse_native.h", + ], + deps = [ + ":ast_visitor_native", + ":source_position_native", + "//base:ast", + "@com_google_absl//absl/types:variant", + ], +) + cc_library( name = "cel_options", hdrs = [ @@ -605,6 +651,18 @@ cc_test( ], ) +cc_test( + name = "ast_traverse_native_test", + srcs = [ + "ast_traverse_native_test.cc", + ], + deps = [ + ":ast_traverse_native", + ":ast_visitor_native", + "//internal:testing", + ], +) + cc_library( name = "ast_rewrite", srcs = [ @@ -798,6 +856,18 @@ cc_test( ], ) +cc_test( + name = "source_position_native_test", + size = "small", + srcs = [ + "source_position_native_test.cc", + ], + deps = [ + ":source_position_native", + "//internal:testing", + ], +) + cc_test( name = "unknown_attribute_set_test", size = "small", diff --git a/eval/public/ast_traverse_native.cc b/eval/public/ast_traverse_native.cc new file mode 100644 index 000000000..498a2aed6 --- /dev/null +++ b/eval/public/ast_traverse_native.cc @@ -0,0 +1,345 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/ast_traverse_native.h" + +#include + +#include "absl/types/variant.h" +#include "base/ast.h" +#include "eval/public/ast_visitor_native.h" +#include "eval/public/source_position_native.h" + +namespace cel::ast::internal { + +namespace { + +struct ArgRecord { + // Not null. + const Expr* expr; + // Not null. + const SourceInfo* source_info; + + // For records that are direct arguments to call, we need to call + // the CallArg visitor immediately after the argument is evaluated. + const Expr* calling_expr; + int call_arg; +}; + +struct ComprehensionRecord { + // Not null. + const Expr* expr; + // Not null. + const SourceInfo* source_info; + + const Comprehension* comprehension; + const Expr* comprehension_expr; + ComprehensionArg comprehension_arg; + bool use_comprehension_callbacks; +}; + +struct ExprRecord { + // Not null. + const Expr* expr; + // Not null. + const SourceInfo* source_info; +}; + +using StackRecordKind = + absl::variant; + +struct StackRecord { + public: + ABSL_ATTRIBUTE_UNUSED static constexpr int kNotCallArg = -1; + static constexpr int kTarget = -2; + + StackRecord(const Expr* e, const SourceInfo* info) { + ExprRecord record; + record.expr = e; + record.source_info = info; + record_variant = record; + } + + StackRecord(const Expr* e, const SourceInfo* info, + const Comprehension* comprehension, + const Expr* comprehension_expr, + ComprehensionArg comprehension_arg, + bool use_comprehension_callbacks) { + if (use_comprehension_callbacks) { + ComprehensionRecord record; + record.expr = e; + record.source_info = info; + record.comprehension = comprehension; + record.comprehension_expr = comprehension_expr; + record.comprehension_arg = comprehension_arg; + record.use_comprehension_callbacks = use_comprehension_callbacks; + record_variant = record; + return; + } + ArgRecord record; + record.expr = e; + record.source_info = info; + record.calling_expr = comprehension_expr; + record.call_arg = comprehension_arg; + record_variant = record; + } + + StackRecord(const Expr* e, const SourceInfo* info, const Expr* call, + int argnum) { + ArgRecord record; + record.expr = e; + record.source_info = info; + record.calling_expr = call; + record.call_arg = argnum; + record_variant = record; + } + StackRecordKind record_variant; + bool visited = false; +}; + +struct PreVisitor { + void operator()(const ExprRecord& record) { + const Expr* expr = record.expr; + const SourcePosition position(expr->id(), record.source_info); + visitor->PreVisitExpr(expr, &position); + if (expr->has_select_expr()) { + visitor->PreVisitSelect(&expr->select_expr(), expr, &position); + } else if (expr->has_call_expr()) { + visitor->PreVisitCall(&expr->call_expr(), expr, &position); + } else if (expr->has_comprehension_expr()) { + visitor->PreVisitComprehension(&expr->comprehension_expr(), expr, + &position); + } else { + // No pre-visit action. + } + } + + // Do nothing for Arg variant. + void operator()(const ArgRecord&) {} + + void operator()(const ComprehensionRecord& record) { + const Expr* expr = record.expr; + const SourcePosition position(expr->id(), record.source_info); + visitor->PreVisitComprehensionSubexpression( + expr, record.comprehension, record.comprehension_arg, &position); + } + + AstVisitor* visitor; +}; + +void PreVisit(const StackRecord& record, AstVisitor* visitor) { + absl::visit(PreVisitor{visitor}, record.record_variant); +} + +struct PostVisitor { + void operator()(const ExprRecord& record) { + const Expr* expr = record.expr; + const SourcePosition position(expr->id(), record.source_info); + struct { + AstVisitor* visitor; + const Expr* expr; + const SourcePosition& position; + void operator()(const Constant& constant) { + visitor->PostVisitConst(&expr->const_expr(), expr, &position); + } + void operator()(const Ident& ident) { + visitor->PostVisitIdent(&expr->ident_expr(), expr, &position); + } + void operator()(const Select& select) { + visitor->PostVisitSelect(&expr->select_expr(), expr, &position); + } + void operator()(const Call& call) { + visitor->PostVisitCall(&expr->call_expr(), expr, &position); + } + void operator()(const CreateList& create_list) { + visitor->PostVisitCreateList(&expr->list_expr(), expr, &position); + } + void operator()(const CreateStruct& create_struct) { + visitor->PostVisitCreateStruct(&expr->struct_expr(), expr, &position); + } + void operator()(const Comprehension& comprehension) { + visitor->PostVisitComprehension(&expr->comprehension_expr(), expr, + &position); + } + } handler{visitor, record.expr, + SourcePosition(expr->id(), record.source_info)}; + absl::visit(handler, record.expr->expr_kind()); + + visitor->PostVisitExpr(expr, &position); + } + + void operator()(const ArgRecord& record) { + const Expr* expr = record.expr; + const SourcePosition position(expr->id(), record.source_info); + if (record.call_arg == StackRecord::kTarget) { + visitor->PostVisitTarget(record.calling_expr, &position); + } else { + visitor->PostVisitArg(record.call_arg, record.calling_expr, &position); + } + } + + void operator()(const ComprehensionRecord& record) { + const Expr* expr = record.expr; + const SourcePosition position(expr->id(), record.source_info); + visitor->PostVisitComprehensionSubexpression( + expr, record.comprehension, record.comprehension_arg, &position); + } + + AstVisitor* visitor; +}; + +void PostVisit(const StackRecord& record, AstVisitor* visitor) { + absl::visit(PostVisitor{visitor}, record.record_variant); +} + +void PushSelectDeps(const Select* select_expr, const SourceInfo* source_info, + std::stack* stack) { + if (select_expr->has_operand()) { + stack->push(StackRecord(&select_expr->operand(), source_info)); + } +} + +void PushCallDeps(const Call* call_expr, const Expr* expr, + const SourceInfo* source_info, + std::stack* stack) { + const int arg_size = call_expr->args().size(); + // Our contract is that we visit arguments in order. To do that, we need + // to push them onto the stack in reverse order. + for (int i = arg_size - 1; i >= 0; --i) { + stack->push(StackRecord(&call_expr->args()[i], source_info, expr, i)); + } + // Are we receiver-style? + if (call_expr->has_target()) { + stack->push(StackRecord(&call_expr->target(), source_info, expr, + StackRecord::kTarget)); + } +} + +void PushListDeps(const CreateList* list_expr, const SourceInfo* source_info, + std::stack* stack) { + const auto& elements = list_expr->elements(); + for (auto it = elements.rbegin(); it != elements.rend(); ++it) { + const auto& element = *it; + stack->push(StackRecord(&element, source_info)); + } +} + +void PushStructDeps(const CreateStruct* struct_expr, + const SourceInfo* source_info, + std::stack* stack) { + const auto& entries = struct_expr->entries(); + for (auto it = entries.rbegin(); it != entries.rend(); ++it) { + const auto& entry = *it; + // The contract is to visit key, then value. So put them on the stack + // in the opposite order. + if (entry.has_value()) { + stack->push(StackRecord(&entry.value(), source_info)); + } + + if (entry.has_map_key()) { + stack->push(StackRecord(&entry.map_key(), source_info)); + } + } +} + +void PushComprehensionDeps(const Comprehension* c, const Expr* expr, + const SourceInfo* source_info, + std::stack* stack, + bool use_comprehension_callbacks) { + StackRecord iter_range(&c->iter_range(), source_info, c, expr, ITER_RANGE, + use_comprehension_callbacks); + StackRecord accu_init(&c->accu_init(), source_info, c, expr, ACCU_INIT, + use_comprehension_callbacks); + StackRecord loop_condition(&c->loop_condition(), source_info, c, expr, + LOOP_CONDITION, use_comprehension_callbacks); + StackRecord loop_step(&c->loop_step(), source_info, c, expr, LOOP_STEP, + use_comprehension_callbacks); + StackRecord result(&c->result(), source_info, c, expr, RESULT, + use_comprehension_callbacks); + // Push them in reverse order. + stack->push(result); + stack->push(loop_step); + stack->push(loop_condition); + stack->push(accu_init); + stack->push(iter_range); +} + +struct PushDepsVisitor { + void operator()(const ExprRecord& record) { + struct { + std::stack& stack; + const TraversalOptions& options; + const ExprRecord& record; + void operator()(const Constant& constant) {} + void operator()(const Ident& ident) {} + void operator()(const Select& select) { + PushSelectDeps(&record.expr->select_expr(), record.source_info, &stack); + } + void operator()(const Call& call) { + PushCallDeps(&record.expr->call_expr(), record.expr, record.source_info, + &stack); + } + void operator()(const CreateList& create_list) { + PushListDeps(&record.expr->list_expr(), record.source_info, &stack); + } + void operator()(const CreateStruct& create_struct) { + PushStructDeps(&record.expr->struct_expr(), record.source_info, &stack); + } + void operator()(const Comprehension& comprehension) { + PushComprehensionDeps(&record.expr->comprehension_expr(), record.expr, + record.source_info, &stack, + options.use_comprehension_callbacks); + } + } handler{stack, options, record}; + absl::visit(handler, record.expr->expr_kind()); + } + + void operator()(const ArgRecord& record) { + stack.push(StackRecord(record.expr, record.source_info)); + } + + void operator()(const ComprehensionRecord& record) { + stack.push(StackRecord(record.expr, record.source_info)); + } + + std::stack& stack; + const TraversalOptions& options; +}; + +void PushDependencies(const StackRecord& record, std::stack& stack, + const TraversalOptions& options) { + absl::visit(PushDepsVisitor{stack, options}, record.record_variant); +} + +} // namespace + +void AstTraverse(const Expr* expr, const SourceInfo* source_info, + AstVisitor* visitor, TraversalOptions options) { + std::stack stack; + stack.push(StackRecord(expr, source_info)); + + while (!stack.empty()) { + StackRecord& record = stack.top(); + if (!record.visited) { + PreVisit(record, visitor); + PushDependencies(record, stack, options); + record.visited = true; + } else { + PostVisit(record, visitor); + stack.pop(); + } + } +} + +} // namespace cel::ast::internal diff --git a/eval/public/ast_traverse_native.h b/eval/public/ast_traverse_native.h new file mode 100644 index 000000000..d65c052e7 --- /dev/null +++ b/eval/public/ast_traverse_native.h @@ -0,0 +1,66 @@ +/* + * Copyright 2018 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_AST_TRAVERSE_NATIVE_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_AST_TRAVERSE_NATIVE_H_ + +#include "base/ast.h" +#include "eval/public/ast_visitor_native.h" + +namespace cel::ast::internal { + +struct TraversalOptions { + bool use_comprehension_callbacks; + + TraversalOptions() : use_comprehension_callbacks(false) {} +}; + +// Traverses the AST representation in an expr proto. +// +// expr: root node of the tree. +// source_info: optional additional parse information about the expression +// visitor: the callback object that receives the visitation notifications +// +// Traversal order follows the pattern: +// PreVisitExpr +// ..PreVisit{ExprKind} +// ....PreVisit{ArgumentIndex} +// .......PreVisitExpr (subtree) +// .......PostVisitExpr (subtree) +// ....PostVisit{ArgumentIndex} +// ..PostVisit{ExprKind} +// PostVisitExpr +// +// Example callback order for fn(1, var): +// PreVisitExpr +// ..PreVisitCall(fn) +// ......PreVisitExpr +// ........PostVisitConst(1) +// ......PostVisitExpr +// ....PostVisitArg(fn, 0) +// ......PreVisitExpr +// ........PostVisitIdent(var) +// ......PostVisitExpr +// ....PostVisitArg(fn, 1) +// ..PostVisitCall(fn) +// PostVisitExpr +void AstTraverse(const Expr* expr, const SourceInfo* source_info, + AstVisitor* visitor, + TraversalOptions options = TraversalOptions()); + +} // namespace cel::ast::internal + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_AST_TRAVERSE_NATIVE_H_ diff --git a/eval/public/ast_traverse_native_test.cc b/eval/public/ast_traverse_native_test.cc new file mode 100644 index 000000000..a4a369d04 --- /dev/null +++ b/eval/public/ast_traverse_native_test.cc @@ -0,0 +1,438 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/ast_traverse_native.h" + +#include "eval/public/ast_visitor_native.h" +#include "internal/testing.h" + +namespace cel::ast::internal { + +namespace { + +using testing::_; + +class MockAstVisitor : public AstVisitor { + public: + // Expr handler. + MOCK_METHOD(void, PreVisitExpr, + (const Expr* expr, const SourcePosition* position), (override)); + + // Expr handler. + MOCK_METHOD(void, PostVisitExpr, + (const Expr* expr, const SourcePosition* position), (override)); + + MOCK_METHOD(void, PostVisitConst, + (const Constant* const_expr, const Expr* expr, + const SourcePosition* position), + (override)); + + // Ident node handler. + MOCK_METHOD(void, PostVisitIdent, + (const Ident* ident_expr, const Expr* expr, + const SourcePosition* position), + (override)); + + // Select node handler group + MOCK_METHOD(void, PreVisitSelect, + (const Select* select_expr, const Expr* expr, + const SourcePosition* position), + (override)); + + MOCK_METHOD(void, PostVisitSelect, + (const Select* select_expr, const Expr* expr, + const SourcePosition* position), + (override)); + + // Call node handler group + MOCK_METHOD(void, PreVisitCall, + (const Call* call_expr, const Expr* expr, + const SourcePosition* position), + (override)); + MOCK_METHOD(void, PostVisitCall, + (const Call* call_expr, const Expr* expr, + const SourcePosition* position), + (override)); + + // Comprehension node handler group + MOCK_METHOD(void, PreVisitComprehension, + (const Comprehension* comprehension_expr, const Expr* expr, + const SourcePosition* position), + (override)); + MOCK_METHOD(void, PostVisitComprehension, + (const Comprehension* comprehension_expr, const Expr* expr, + const SourcePosition* position), + (override)); + + // Comprehension node handler group + MOCK_METHOD(void, PreVisitComprehensionSubexpression, + (const Expr* expr, const Comprehension* comprehension_expr, + ComprehensionArg comprehension_arg, + const SourcePosition* position), + (override)); + MOCK_METHOD(void, PostVisitComprehensionSubexpression, + (const Expr* expr, const Comprehension* comprehension_expr, + ComprehensionArg comprehension_arg, + const SourcePosition* position), + (override)); + + // We provide finer granularity for Call and Comprehension node callbacks + // to allow special handling for short-circuiting. + MOCK_METHOD(void, PostVisitTarget, + (const Expr* expr, const SourcePosition* position), (override)); + MOCK_METHOD(void, PostVisitArg, + (int arg_num, const Expr* expr, const SourcePosition* position), + (override)); + + // CreateList node handler group + MOCK_METHOD(void, PostVisitCreateList, + (const CreateList* list_expr, const Expr* expr, + const SourcePosition* position), + (override)); + + // CreateStruct node handler group + MOCK_METHOD(void, PostVisitCreateStruct, + (const CreateStruct* struct_expr, const Expr* expr, + const SourcePosition* position), + (override)); +}; + +TEST(AstCrawlerTest, CheckCrawlConstant) { + SourceInfo source_info; + MockAstVisitor handler; + + Expr expr; + auto& const_expr = expr.mutable_const_expr(); + + EXPECT_CALL(handler, PostVisitConst(&const_expr, &expr, _)).Times(1); + + AstTraverse(&expr, &source_info, &handler); +} + +TEST(AstCrawlerTest, CheckCrawlIdent) { + SourceInfo source_info; + MockAstVisitor handler; + + Expr expr; + auto& ident_expr = expr.mutable_ident_expr(); + + EXPECT_CALL(handler, PostVisitIdent(&ident_expr, &expr, _)).Times(1); + + AstTraverse(&expr, &source_info, &handler); +} + +// Test handling of Select node when operand is not set. +TEST(AstCrawlerTest, CheckCrawlSelectNotCrashingPostVisitAbsentOperand) { + SourceInfo source_info; + MockAstVisitor handler; + + Expr expr; + auto& select_expr = expr.mutable_select_expr(); + + // Lowest level entry will be called first + EXPECT_CALL(handler, PostVisitSelect(&select_expr, &expr, _)).Times(1); + + AstTraverse(&expr, &source_info, &handler); +} + +// Test handling of Select node +TEST(AstCrawlerTest, CheckCrawlSelect) { + SourceInfo source_info; + MockAstVisitor handler; + + Expr expr; + auto& select_expr = expr.mutable_select_expr(); + auto& operand = select_expr.mutable_operand(); + auto& ident_expr = operand.mutable_ident_expr(); + + testing::InSequence seq; + + // Lowest level entry will be called first + EXPECT_CALL(handler, PostVisitIdent(&ident_expr, &operand, _)).Times(1); + EXPECT_CALL(handler, PostVisitSelect(&select_expr, &expr, _)).Times(1); + + AstTraverse(&expr, &source_info, &handler); +} + +// Test handling of Call node without receiver +TEST(AstCrawlerTest, CheckCrawlCallNoReceiver) { + SourceInfo source_info; + MockAstVisitor handler; + + // (, ) + Expr expr; + auto& call_expr = expr.mutable_call_expr(); + call_expr.mutable_args().reserve(2); + Expr& arg0 = call_expr.mutable_args().emplace_back(); + auto& const_expr = arg0.mutable_const_expr(); + Expr& arg1 = call_expr.mutable_args().emplace_back(); + auto& ident_expr = arg1.mutable_ident_expr(); + + testing::InSequence seq; + + // Lowest level entry will be called first + EXPECT_CALL(handler, PreVisitCall(&call_expr, &expr, _)).Times(1); + EXPECT_CALL(handler, PostVisitTarget(_, _)).Times(0); + + // Arg0 + EXPECT_CALL(handler, PostVisitConst(&const_expr, &arg0, _)).Times(1); + EXPECT_CALL(handler, PostVisitExpr(&arg0, _)).Times(1); + EXPECT_CALL(handler, PostVisitArg(0, &expr, _)).Times(1); + + // Arg1 + EXPECT_CALL(handler, PostVisitIdent(&ident_expr, &arg1, _)).Times(1); + EXPECT_CALL(handler, PostVisitExpr(&arg1, _)).Times(1); + EXPECT_CALL(handler, PostVisitArg(1, &expr, _)).Times(1); + + // Back to call + EXPECT_CALL(handler, PostVisitCall(&call_expr, &expr, _)).Times(1); + EXPECT_CALL(handler, PostVisitExpr(&expr, _)).Times(1); + + AstTraverse(&expr, &source_info, &handler); +} + +// Test handling of Call node with receiver +TEST(AstCrawlerTest, CheckCrawlCallReceiver) { + SourceInfo source_info; + MockAstVisitor handler; + + // .(, ) + Expr expr; + auto& call_expr = expr.mutable_call_expr(); + Expr& target = call_expr.mutable_target(); + auto& target_ident = target.mutable_ident_expr(); + call_expr.mutable_args().reserve(2); + Expr& arg0 = call_expr.mutable_args().emplace_back(); + auto& const_expr = arg0.mutable_const_expr(); + Expr& arg1 = call_expr.mutable_args().emplace_back(); + auto& ident_expr = arg1.mutable_ident_expr(); + + testing::InSequence seq; + + // Lowest level entry will be called first + EXPECT_CALL(handler, PreVisitCall(&call_expr, &expr, _)).Times(1); + + // Target + EXPECT_CALL(handler, PostVisitIdent(&target_ident, &target, _)).Times(1); + EXPECT_CALL(handler, PostVisitExpr(&target, _)).Times(1); + EXPECT_CALL(handler, PostVisitTarget(&expr, _)).Times(1); + + // Arg0 + EXPECT_CALL(handler, PostVisitConst(&const_expr, &arg0, _)).Times(1); + EXPECT_CALL(handler, PostVisitExpr(&arg0, _)).Times(1); + EXPECT_CALL(handler, PostVisitArg(0, &expr, _)).Times(1); + + // Arg1 + EXPECT_CALL(handler, PostVisitIdent(&ident_expr, &arg1, _)).Times(1); + EXPECT_CALL(handler, PostVisitExpr(&arg1, _)).Times(1); + EXPECT_CALL(handler, PostVisitArg(1, &expr, _)).Times(1); + + // Back to call + EXPECT_CALL(handler, PostVisitCall(&call_expr, &expr, _)).Times(1); + EXPECT_CALL(handler, PostVisitExpr(&expr, _)).Times(1); + + AstTraverse(&expr, &source_info, &handler); +} + +// Test handling of Comprehension node +TEST(AstCrawlerTest, CheckCrawlComprehension) { + SourceInfo source_info; + MockAstVisitor handler; + + Expr expr; + auto& c = expr.mutable_comprehension_expr(); + auto& iter_range = c.mutable_iter_range(); + auto& iter_range_expr = iter_range.mutable_const_expr(); + auto& accu_init = c.mutable_accu_init(); + auto& accu_init_expr = accu_init.mutable_ident_expr(); + auto& loop_condition = c.mutable_loop_condition(); + auto& loop_condition_expr = loop_condition.mutable_const_expr(); + auto& loop_step = c.mutable_loop_step(); + auto& loop_step_expr = loop_step.mutable_ident_expr(); + auto& result = c.mutable_result(); + auto& result_expr = result.mutable_const_expr(); + + testing::InSequence seq; + + // Lowest level entry will be called first + EXPECT_CALL(handler, PreVisitComprehension(&c, &expr, _)).Times(1); + + EXPECT_CALL(handler, PreVisitComprehensionSubexpression(&iter_range, &c, + ITER_RANGE, _)) + .Times(1); + EXPECT_CALL(handler, PostVisitConst(&iter_range_expr, &iter_range, _)) + .Times(1); + EXPECT_CALL(handler, PostVisitComprehensionSubexpression(&iter_range, &c, + ITER_RANGE, _)) + .Times(1); + + // ACCU_INIT + EXPECT_CALL(handler, + PreVisitComprehensionSubexpression(&accu_init, &c, ACCU_INIT, _)) + .Times(1); + EXPECT_CALL(handler, PostVisitIdent(&accu_init_expr, &accu_init, _)).Times(1); + EXPECT_CALL(handler, + PostVisitComprehensionSubexpression(&accu_init, &c, ACCU_INIT, _)) + .Times(1); + + // LOOP CONDITION + EXPECT_CALL(handler, PreVisitComprehensionSubexpression(&loop_condition, &c, + LOOP_CONDITION, _)) + .Times(1); + EXPECT_CALL(handler, PostVisitConst(&loop_condition_expr, &loop_condition, _)) + .Times(1); + EXPECT_CALL(handler, PostVisitComprehensionSubexpression(&loop_condition, &c, + LOOP_CONDITION, _)) + .Times(1); + + // LOOP STEP + EXPECT_CALL(handler, + PreVisitComprehensionSubexpression(&loop_step, &c, LOOP_STEP, _)) + .Times(1); + EXPECT_CALL(handler, PostVisitIdent(&loop_step_expr, &loop_step, _)).Times(1); + EXPECT_CALL(handler, + PostVisitComprehensionSubexpression(&loop_step, &c, LOOP_STEP, _)) + .Times(1); + + // RESULT + EXPECT_CALL(handler, + PreVisitComprehensionSubexpression(&result, &c, RESULT, _)) + .Times(1); + + EXPECT_CALL(handler, PostVisitConst(&result_expr, &result, _)).Times(1); + + EXPECT_CALL(handler, + PostVisitComprehensionSubexpression(&result, &c, RESULT, _)) + .Times(1); + + EXPECT_CALL(handler, PostVisitComprehension(&c, &expr, _)).Times(1); + + TraversalOptions opts; + opts.use_comprehension_callbacks = true; + AstTraverse(&expr, &source_info, &handler, opts); +} + +// Test handling of Comprehension node +TEST(AstCrawlerTest, CheckCrawlComprehensionLegacyCallbacks) { + SourceInfo source_info; + MockAstVisitor handler; + + Expr expr; + auto& c = expr.mutable_comprehension_expr(); + auto& iter_range = c.mutable_iter_range(); + auto& iter_range_expr = iter_range.mutable_const_expr(); + auto& accu_init = c.mutable_accu_init(); + auto& accu_init_expr = accu_init.mutable_ident_expr(); + auto& loop_condition = c.mutable_loop_condition(); + auto& loop_condition_expr = loop_condition.mutable_const_expr(); + auto& loop_step = c.mutable_loop_step(); + auto& loop_step_expr = loop_step.mutable_ident_expr(); + auto& result = c.mutable_result(); + auto& result_expr = result.mutable_const_expr(); + + testing::InSequence seq; + + // Lowest level entry will be called first + EXPECT_CALL(handler, PreVisitComprehension(&c, &expr, _)).Times(1); + + EXPECT_CALL(handler, PostVisitConst(&iter_range_expr, &iter_range, _)) + .Times(1); + EXPECT_CALL(handler, PostVisitArg(ITER_RANGE, &expr, _)).Times(1); + + // ACCU_INIT + EXPECT_CALL(handler, PostVisitIdent(&accu_init_expr, &accu_init, _)).Times(1); + EXPECT_CALL(handler, PostVisitArg(ACCU_INIT, &expr, _)).Times(1); + + // LOOP CONDITION + EXPECT_CALL(handler, PostVisitConst(&loop_condition_expr, &loop_condition, _)) + .Times(1); + EXPECT_CALL(handler, PostVisitArg(LOOP_CONDITION, &expr, _)).Times(1); + + // LOOP STEP + EXPECT_CALL(handler, PostVisitIdent(&loop_step_expr, &loop_step, _)).Times(1); + EXPECT_CALL(handler, PostVisitArg(LOOP_STEP, &expr, _)).Times(1); + + // RESULT + EXPECT_CALL(handler, PostVisitConst(&result_expr, &result, _)).Times(1); + EXPECT_CALL(handler, PostVisitArg(RESULT, &expr, _)).Times(1); + + EXPECT_CALL(handler, PostVisitComprehension(&c, &expr, _)).Times(1); + + AstTraverse(&expr, &source_info, &handler); +} + +// Test handling of CreateList node. +TEST(AstCrawlerTest, CheckCreateList) { + SourceInfo source_info; + MockAstVisitor handler; + + Expr expr; + auto& list_expr = expr.mutable_list_expr(); + list_expr.mutable_elements().reserve(2); + auto& arg0 = list_expr.mutable_elements().emplace_back(); + auto& const_expr = arg0.mutable_const_expr(); + auto& arg1 = list_expr.mutable_elements().emplace_back(); + auto& ident_expr = arg1.mutable_ident_expr(); + + testing::InSequence seq; + + EXPECT_CALL(handler, PostVisitConst(&const_expr, &arg0, _)).Times(1); + EXPECT_CALL(handler, PostVisitIdent(&ident_expr, &arg1, _)).Times(1); + EXPECT_CALL(handler, PostVisitCreateList(&list_expr, &expr, _)).Times(1); + + AstTraverse(&expr, &source_info, &handler); +} + +// Test handling of CreateStruct node. +TEST(AstCrawlerTest, CheckCreateStruct) { + SourceInfo source_info; + MockAstVisitor handler; + + Expr expr; + auto& struct_expr = expr.mutable_struct_expr(); + auto& entry0 = struct_expr.mutable_entries().emplace_back(); + + auto& key = entry0.mutable_map_key().mutable_const_expr(); + auto& value = entry0.mutable_value().mutable_ident_expr(); + + testing::InSequence seq; + + EXPECT_CALL(handler, PostVisitConst(&key, &entry0.map_key(), _)).Times(1); + EXPECT_CALL(handler, PostVisitIdent(&value, &entry0.value(), _)).Times(1); + EXPECT_CALL(handler, PostVisitCreateStruct(&struct_expr, &expr, _)).Times(1); + + AstTraverse(&expr, &source_info, &handler); +} + +// Test generic Expr handlers. +TEST(AstCrawlerTest, CheckExprHandlers) { + SourceInfo source_info; + MockAstVisitor handler; + + Expr expr; + auto& struct_expr = expr.mutable_struct_expr(); + auto& entry0 = struct_expr.mutable_entries().emplace_back(); + + entry0.mutable_map_key().mutable_const_expr(); + entry0.mutable_value().mutable_ident_expr(); + + EXPECT_CALL(handler, PreVisitExpr(_, _)).Times(3); + EXPECT_CALL(handler, PostVisitExpr(_, _)).Times(3); + + AstTraverse(&expr, &source_info, &handler); +} + +} // namespace + +} // namespace cel::ast::internal diff --git a/eval/public/ast_visitor_native.h b/eval/public/ast_visitor_native.h new file mode 100644 index 000000000..5a1c253e1 --- /dev/null +++ b/eval/public/ast_visitor_native.h @@ -0,0 +1,130 @@ +/* + * Copyright 2018 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_AST_VISITOR_NATIVE_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_AST_VISITOR_NATIVE_H_ + +#include "base/ast.h" +#include "eval/public/source_position_native.h" + +namespace cel { +namespace ast { +namespace internal { + +// ComprehensionArg specifies arg_num values passed to PostVisitArg +// for subexpressions of Comprehension. +enum ComprehensionArg { + ITER_RANGE, + ACCU_INIT, + LOOP_CONDITION, + LOOP_STEP, + RESULT, +}; + +// Callback handler class, used in conjunction with AstTraverse. +// Methods of this class are invoked when AST nodes with corresponding +// types are processed. +// +// For all types with children, the children will be visited in the natural +// order from first to last. For structs, keys are visited before values. +class AstVisitor { + public: + virtual ~AstVisitor() {} + + // Expr node handler method. Called for all Expr nodes. + // Is invoked before child Expr nodes being processed. + virtual void PreVisitExpr(const Expr*, const SourcePosition*) = 0; + + // Expr node handler method. Called for all Expr nodes. + // Is invoked after child Expr nodes are processed. + virtual void PostVisitExpr(const Expr*, const SourcePosition*) = 0; + + // Const node handler. + // Invoked after child nodes are processed. + virtual void PostVisitConst(const Constant*, const Expr*, + const SourcePosition*) = 0; + + // Ident node handler. + // Invoked after child nodes are processed. + virtual void PostVisitIdent(const Ident*, const Expr*, + const SourcePosition*) = 0; + + // Select node handler + // Invoked before child nodes are processed. + virtual void PreVisitSelect(const Select*, const Expr*, + const SourcePosition*) = 0; + + // Select node handler + // Invoked after child nodes are processed. + virtual void PostVisitSelect(const Select*, const Expr*, + const SourcePosition*) = 0; + + // Call node handler group + // We provide finer granularity for Call node callbacks to allow special + // handling for short-circuiting + // PreVisitCall is invoked before child nodes are processed. + virtual void PreVisitCall(const Call*, const Expr*, + const SourcePosition*) = 0; + + // Invoked after all child nodes are processed. + virtual void PostVisitCall(const Call*, const Expr*, + const SourcePosition*) = 0; + + // Invoked after target node is processed. + // Expr is the call expression. + virtual void PostVisitTarget(const Expr*, const SourcePosition*) = 0; + + // Invoked before all child nodes are processed. + virtual void PreVisitComprehension(const Comprehension*, const Expr*, + const SourcePosition*) = 0; + + // Invoked before comprehension child node is processed. + virtual void PreVisitComprehensionSubexpression( + const Expr* subexpr, const Comprehension* compr, + ComprehensionArg comprehension_arg, const SourcePosition*) {} + + // Invoked after comprehension child node is processed. + virtual void PostVisitComprehensionSubexpression( + const Expr* subexpr, const Comprehension* compr, + ComprehensionArg comprehension_arg, const SourcePosition*) {} + + // Invoked after all child nodes are processed. + virtual void PostVisitComprehension(const Comprehension*, const Expr*, + const SourcePosition*) = 0; + + // Invoked after each argument node processed. + // For Call arg_num is the index of the argument. + // For Comprehension arg_num is specified by ComprehensionArg. + // Expr is the call expression. + virtual void PostVisitArg(int arg_num, const Expr*, + const SourcePosition*) = 0; + + // CreateList node handler + // Invoked after child nodes are processed. + virtual void PostVisitCreateList(const CreateList*, const Expr*, + const SourcePosition*) = 0; + + // CreateStruct node handler + // Invoked after child nodes are processed. + virtual void PostVisitCreateStruct(const CreateStruct*, const Expr*, + const SourcePosition*) = 0; +}; + +} // namespace internal +} // namespace ast +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_AST_VISITOR_H_ diff --git a/eval/public/ast_visitor_native_base.h b/eval/public/ast_visitor_native_base.h new file mode 100644 index 000000000..43b8f16e7 --- /dev/null +++ b/eval/public/ast_visitor_native_base.h @@ -0,0 +1,94 @@ +/* + * Copyright 2018 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_AST_VISITOR_BASE_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_AST_VISITOR_BASE_H_ + +#include "eval/public/ast_visitor_native.h" + +namespace cel { +namespace ast { +namespace internal { + +// Trivial base implementation of AstVisitor. +class AstVisitorBase : public AstVisitor { + public: + AstVisitorBase() {} + + // Non-copyable + AstVisitorBase(const AstVisitorBase&) = delete; + AstVisitorBase& operator=(AstVisitorBase const&) = delete; + + ~AstVisitorBase() override {} + + // Const node handler. + // Invoked after child nodes are processed. + void PostVisitConst(const Constant*, const Expr*, + const SourcePosition*) override {} + + // Ident node handler. + // Invoked after child nodes are processed. + void PostVisitIdent(const Ident*, const Expr*, + const SourcePosition*) override {} + + // Select node handler + // Invoked after child nodes are processed. + void PostVisitSelect(const Select*, const Expr*, + const SourcePosition*) override {} + + // Call node handler group + // We provide finer granularity for Call node callbacks to allow special + // handling for short-circuiting + // PreVisitCall is invoked before child nodes are processed. + void PreVisitCall(const Call*, const Expr*, const SourcePosition*) override {} + + // Invoked after all child nodes are processed. + void PostVisitCall(const Call*, const Expr*, const SourcePosition*) override { + } + + // Invoked before all child nodes are processed. + void PreVisitComprehension(const Comprehension*, const Expr*, + const SourcePosition*) override {} + + // Invoked after all child nodes are processed. + void PostVisitComprehension(const Comprehension*, const Expr*, + const SourcePosition*) override {} + + // Invoked after each argument node processed. + // For Call arg_num is the index of the argument. + // For Comprehension arg_num is specified by ComprehensionArg. + // Expr is the call expression. + void PostVisitArg(int, const Expr*, const SourcePosition*) override {} + + // Invoked after target node processed. + void PostVisitTarget(const Expr*, const SourcePosition*) override {} + + // CreateList node handler + // Invoked after child nodes are processed. + void PostVisitCreateList(const CreateList*, const Expr*, + const SourcePosition*) override {} + + // CreateStruct node handler + // Invoked after child nodes are processed. + void PostVisitCreateStruct(const CreateStruct*, const Expr*, + const SourcePosition*) override {} +}; + +} // namespace internal +} // namespace ast +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_AST_VISITOR_BASE_H_ diff --git a/eval/public/source_position_native.cc b/eval/public/source_position_native.cc new file mode 100644 index 000000000..0e1281e1b --- /dev/null +++ b/eval/public/source_position_native.cc @@ -0,0 +1,66 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/source_position_native.h" + +namespace cel { +namespace ast { +namespace internal { + +namespace { + +std::pair GetLineAndLineOffset(const SourceInfo* source_info, + int32_t position) { + int line = 0; + int32_t line_offset = 0; + if (source_info != nullptr) { + for (const auto& curr_line_offset : source_info->line_offsets()) { + if (curr_line_offset > position) { + break; + } + line_offset = curr_line_offset; + line++; + } + } + if (line == 0) { + line++; + } + return std::pair(line, line_offset); +} + +} // namespace + +int32_t SourcePosition::line() const { + return GetLineAndLineOffset(source_info_, character_offset()).first; +} + +int32_t SourcePosition::column() const { + int32_t position = character_offset(); + std::pair line_and_offset = + GetLineAndLineOffset(source_info_, position); + return 1 + (position - line_and_offset.second); +} + +int32_t SourcePosition::character_offset() const { + if (source_info_ == nullptr) { + return 0; + } + auto position_it = source_info_->positions().find(expr_id_); + return position_it != source_info_->positions().end() ? position_it->second + : 0; +} + +} // namespace internal +} // namespace ast +} // namespace cel diff --git a/eval/public/source_position_native.h b/eval/public/source_position_native.h new file mode 100644 index 000000000..fcbba85f5 --- /dev/null +++ b/eval/public/source_position_native.h @@ -0,0 +1,62 @@ +/* + * Copyright 2018 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_SOURCE_POSITION_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_SOURCE_POSITION_H_ + +#include "base/ast.h" + +namespace cel { +namespace ast { +namespace internal { + +// Class representing the source position as well as line and column data for +// a given expression id. +class SourcePosition { + public: + // Constructor for a SourcePosition value. The source_info may be nullptr, + // in which case line, column, and character_offset will return 0. + SourcePosition(const int64_t expr_id, const SourceInfo* source_info) + : expr_id_(expr_id), source_info_(source_info) {} + + // Non-copyable + SourcePosition(const SourcePosition& other) = delete; + SourcePosition& operator=(const SourcePosition& other) = delete; + + virtual ~SourcePosition() {} + + // Return the 1-based source line number for the expression. + int32_t line() const; + + // Return the 1-based column offset within the source line for the + // expression. + int32_t column() const; + + // Return the 0-based character offset of the expression within source. + int32_t character_offset() const; + + private: + // The expression identifier. + const int64_t expr_id_; + // The source information reference generated during expression parsing. + const SourceInfo* source_info_; +}; + +} // namespace internal +} // namespace ast +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_SOURCE_POSITION_H_ diff --git a/eval/public/source_position_native_test.cc b/eval/public/source_position_native_test.cc new file mode 100644 index 000000000..792a79c80 --- /dev/null +++ b/eval/public/source_position_native_test.cc @@ -0,0 +1,108 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/source_position_native.h" + +#include "internal/testing.h" + +namespace cel { +namespace ast { +namespace internal { + +namespace { + +using testing::Eq; + +class SourcePositionTest : public testing::Test { + protected: + void SetUp() override { + // Simulate the expression positions : '\n\na\n&& b\n\n|| c' + // + // Within the ExprChecker, the line offset is the first character of the + // line rather than the newline character. + // + // The tests outputs are affected by leading newlines, but not trailing + // newlines, and the ExprChecker will actually always generate a trailing + // newline entry for EOF; however, this offset is not included in the test + // since there may be other parsers which generate newline information + // slightly differently. + source_info_.mutable_line_offsets().push_back(0); + source_info_.mutable_line_offsets().push_back(1); + source_info_.mutable_line_offsets().push_back(2); + (source_info_.mutable_positions())[1] = 2; + source_info_.mutable_line_offsets().push_back(4); + (source_info_.mutable_positions())[2] = 4; + (source_info_.mutable_positions())[3] = 7; + source_info_.mutable_line_offsets().push_back(9); + source_info_.mutable_line_offsets().push_back(10); + (source_info_.mutable_positions())[4] = 10; + (source_info_.mutable_positions())[5] = 13; + } + + SourceInfo source_info_; +}; + +TEST_F(SourcePositionTest, TestNullSourceInfo) { + SourcePosition position(3, nullptr); + EXPECT_THAT(position.character_offset(), Eq(0)); + EXPECT_THAT(position.line(), Eq(1)); + EXPECT_THAT(position.column(), Eq(1)); +} + +TEST_F(SourcePositionTest, TestNoNewlines) { + source_info_.mutable_line_offsets().clear(); + SourcePosition position(3, &source_info_); + EXPECT_THAT(position.character_offset(), Eq(7)); + EXPECT_THAT(position.line(), Eq(1)); + EXPECT_THAT(position.column(), Eq(8)); +} + +TEST_F(SourcePositionTest, TestPosition) { + SourcePosition position(3, &source_info_); + EXPECT_THAT(position.character_offset(), Eq(7)); +} + +TEST_F(SourcePositionTest, TestLine) { + SourcePosition position1(1, &source_info_); + EXPECT_THAT(position1.line(), Eq(3)); + + SourcePosition position2(2, &source_info_); + EXPECT_THAT(position2.line(), Eq(4)); + + SourcePosition position3(3, &source_info_); + EXPECT_THAT(position3.line(), Eq(4)); + + SourcePosition position4(5, &source_info_); + EXPECT_THAT(position4.line(), Eq(6)); +} + +TEST_F(SourcePositionTest, TestColumn) { + SourcePosition position1(1, &source_info_); + EXPECT_THAT(position1.column(), Eq(1)); + + SourcePosition position2(2, &source_info_); + EXPECT_THAT(position2.column(), Eq(1)); + + SourcePosition position3(3, &source_info_); + EXPECT_THAT(position3.column(), Eq(4)); + + SourcePosition position4(5, &source_info_); + EXPECT_THAT(position4.column(), Eq(4)); +} + +} // namespace + +} // namespace internal +} // namespace ast +} // namespace cel From 9b0b4de848b899a69617ca6bc347e6de11d6b11c Mon Sep 17 00:00:00 2001 From: kuat Date: Wed, 20 Jul 2022 00:06:27 +0000 Subject: [PATCH 014/303] Fix OSS build warnings. PiperOrigin-RevId: 462010326 --- base/handle.h | 2 +- base/internal/data.h | 12 ++---------- base/type.h | 3 +-- base/type_provider.h | 2 +- eval/public/structs/legacy_type_adapter.h | 4 ++-- 5 files changed, 7 insertions(+), 16 deletions(-) diff --git a/base/handle.h b/base/handle.h index 460b2797a..969eb0106 100644 --- a/base/handle.h +++ b/base/handle.h @@ -188,7 +188,7 @@ class Persistent final : private base_internal::HandlePolicy { friend struct base_internal::HandleFactory; template - explicit Persistent(absl::in_place_t in_place, Args&&... args) + explicit Persistent(absl::in_place_t, Args&&... args) : impl_(std::forward(args)...) {} Handle impl_; diff --git a/base/internal/data.h b/base/internal/data.h index f2d8aa905..f0470bd1c 100644 --- a/base/internal/data.h +++ b/base/internal/data.h @@ -162,15 +162,6 @@ class HeapData /* : public Data */ { 0; }; -inline constexpr size_t HeapDataMetadataAndReferenceCountOffset() { - return offsetof(HeapData, metadata_and_reference_count_); -} - -static_assert(HeapDataMetadataAndReferenceCountOffset() == sizeof(uintptr_t), - "Expected vptr to be at offset 0"); -static_assert(sizeof(HeapData) == sizeof(uintptr_t) * 2, - "Unexpected class size"); - // Provides introspection for `Data`. class Metadata final { public: @@ -302,7 +293,8 @@ union alignas(Align) AnyDataStorage final { // dereference our stored pointers as it may have already been deleted. Thus we // need to know if it was arena allocated without dereferencing the pointer. template -struct AnyData { +class AnyData { + public: static_assert(Size >= sizeof(uintptr_t), "Size must be at least sizeof(uintptr_t)"); static_assert(Align >= alignof(uintptr_t), diff --git a/base/type.h b/base/type.h index df43b9278..23057143c 100644 --- a/base/type.h +++ b/base/type.h @@ -106,8 +106,7 @@ class PersistentTypeHandle final { PersistentTypeHandle() = default; template - explicit PersistentTypeHandle(absl::in_place_type_t in_place_type, - Args&&... args) { + explicit PersistentTypeHandle(absl::in_place_type_t, Args&&... args) { data_.ConstructInline(std::forward(args)...); } diff --git a/base/type_provider.h b/base/type_provider.h index cde5befa8..3e5d25c2b 100644 --- a/base/type_provider.h +++ b/base/type_provider.h @@ -49,7 +49,7 @@ class TypeProvider { // // An empty handle is returned if the provider cannot find the requested type. virtual absl::StatusOr> ProvideType( - TypeFactory& type_factory, absl::string_view name) const { + TypeFactory&, absl::string_view) const { return absl::UnimplementedError("ProvideType is not yet implemented"); } }; diff --git a/eval/public/structs/legacy_type_adapter.h b/eval/public/structs/legacy_type_adapter.h index a7659a7bb..1ddc9536e 100644 --- a/eval/public/structs/legacy_type_adapter.h +++ b/eval/public/structs/legacy_type_adapter.h @@ -85,8 +85,8 @@ class LegacyTypeAccessApis { // return whether they are equal. // To conform to the CEL spec, message equality should follow the behavior of // MessageDifferencer::Equals. - virtual bool IsEqualTo(const CelValue::MessageWrapper& instance, - const CelValue::MessageWrapper& other_instance) const { + virtual bool IsEqualTo(const CelValue::MessageWrapper&, + const CelValue::MessageWrapper&) const { return false; } }; From 557e813b0bec926dd3f8155d74f5a1c3618fda87 Mon Sep 17 00:00:00 2001 From: kuat Date: Thu, 21 Jul 2022 01:00:28 +0000 Subject: [PATCH 015/303] GCC and MSVC fixes. PiperOrigin-RevId: 462270835 --- base/internal/data.h | 5 +++++ base/memory_manager.cc | 4 ++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/base/internal/data.h b/base/internal/data.h index f0470bd1c..957b72edb 100644 --- a/base/internal/data.h +++ b/base/internal/data.h @@ -97,6 +97,9 @@ enum class DataLocality { // at least `sizeof(void*)`. class Data {}; +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wattributes" + // Empty base class indicating class must be stored directly in the handle and // not allocated separately on the heap. // @@ -162,6 +165,8 @@ class HeapData /* : public Data */ { 0; }; +#pragma GCC diagnostic pop + // Provides introspection for `Data`. class Metadata final { public: diff --git a/base/memory_manager.cc b/base/memory_manager.cc index 9c2bd2341..1b7b3550a 100644 --- a/base/memory_manager.cc +++ b/base/memory_manager.cc @@ -117,13 +117,13 @@ std::optional ArenaBlockAllocate(size_t size, pointer = VirtualAlloc(hint, size, MEM_COMMIT | MEM_RESERVE, PAGE_READWRITE); if (ABSL_PREDICT_FALSE(pointer == nullptr)) { if (hint == nullptr) { - return absl::nullopt; + return std::nullopt; } // Try again, without the hint. pointer = VirtualAlloc(nullptr, size, MEM_COMMIT | MEM_RESERVE, PAGE_READWRITE); if (pointer == nullptr) { - return absl::nullopt; + return std::nullopt; } } #endif From 280925b712bd381e11b265c8585f15d4752928ad Mon Sep 17 00:00:00 2001 From: Kuat Yessenov Date: Wed, 31 Aug 2022 20:38:45 -0700 Subject: [PATCH 016/303] build: fix antlr tag Signed-off-by: Kuat Yessenov --- bazel/deps.bzl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bazel/deps.bzl b/bazel/deps.bzl index 669f3b514..44be7be8d 100644 --- a/bazel/deps.bzl +++ b/bazel/deps.bzl @@ -86,7 +86,7 @@ cc_library( """, sha256 = "a320568b738e42735946bebc5d9d333170e14a251c5734e8b852ad1502efa8a2", strip_prefix = "antlr4-" + ANTLR4_VERSION, - urls = ["https://github.com/antlr/antlr4/archive/v" + ANTLR4_VERSION + ".tar.gz"], + urls = ["https://github.com/antlr/antlr4/archive/" + ANTLR4_VERSION + ".tar.gz"], ) http_jar( name = "antlr4_jar", From 0b2e65c4675c04bf7fea484d047b20d36f001d56 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Fri, 22 Jul 2022 20:30:48 +0000 Subject: [PATCH 017/303] Update reference resolve step to warn if reference map uses ID 0. The parsers should not use ID 0, and it causes problems for custom ASTs (which tend to leave the ID unset, default 0). PiperOrigin-RevId: 462691976 --- bazel/deps.bzl | 2 +- eval/compiler/constant_folding.cc | 4 +- eval/compiler/flat_expr_builder.cc | 2 +- eval/compiler/qualified_reference_resolver.cc | 8 ++++ eval/compiler/qualified_reference_resolver.h | 1 - .../qualified_reference_resolver_test.cc | 46 ++++++++++++++++++- eval/eval/evaluator_core.cc | 4 +- eval/eval/evaluator_stack.h | 14 +++--- eval/eval/select_step.cc | 2 +- eval/public/ast_rewrite.cc | 2 +- eval/public/ast_traverse.cc | 2 +- eval/public/builtin_func_registrar.cc | 2 +- eval/public/cel_expr_builder_factory.cc | 4 +- eval/public/cel_value.h | 4 +- .../portable_cel_expr_builder_factory.cc | 2 +- eval/public/structs/field_access_impl.cc | 2 +- 16 files changed, 76 insertions(+), 25 deletions(-) diff --git a/bazel/deps.bzl b/bazel/deps.bzl index 44be7be8d..669f3b514 100644 --- a/bazel/deps.bzl +++ b/bazel/deps.bzl @@ -86,7 +86,7 @@ cc_library( """, sha256 = "a320568b738e42735946bebc5d9d333170e14a251c5734e8b852ad1502efa8a2", strip_prefix = "antlr4-" + ANTLR4_VERSION, - urls = ["https://github.com/antlr/antlr4/archive/" + ANTLR4_VERSION + ".tar.gz"], + urls = ["https://github.com/antlr/antlr4/archive/v" + ANTLR4_VERSION + ".tar.gz"], ) http_jar( name = "antlr4_jar", diff --git a/eval/compiler/constant_folding.cc b/eval/compiler/constant_folding.cc index 115467346..97d810589 100644 --- a/eval/compiler/constant_folding.cc +++ b/eval/compiler/constant_folding.cc @@ -169,7 +169,7 @@ class ConstantFoldingTransform { Transform(entry.map_key(), new_entry->mutable_map_key()); break; default: - GOOGLE_LOG(ERROR) << "Unsupported Entry kind: " << entry.key_kind_case(); + LOG(ERROR) << "Unsupported Entry kind: " << entry.key_kind_case(); break; } Transform(entry.value(), new_entry->mutable_value()); @@ -192,7 +192,7 @@ class ConstantFoldingTransform { return false; } default: - GOOGLE_LOG(ERROR) << "Unsupported Expr kind: " << expr.expr_kind_case(); + LOG(ERROR) << "Unsupported Expr kind: " << expr.expr_kind_case(); return false; } } diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index 999d03ad8..e57fc959c 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -193,7 +193,7 @@ class FlatExprVisitor : public AstVisitor { enable_wrapper_type_null_unboxing_(enable_wrapper_type_null_unboxing), builder_warnings_(warnings), iter_variable_names_(iter_variable_names) { - GOOGLE_CHECK(iter_variable_names_); + DCHECK(iter_variable_names_); } void PreVisitExpr(const Expr* expr, const SourcePosition*) override { diff --git a/eval/compiler/qualified_reference_resolver.cc b/eval/compiler/qualified_reference_resolver.cc index 00e137438..e93d28c1d 100644 --- a/eval/compiler/qualified_reference_resolver.cc +++ b/eval/compiler/qualified_reference_resolver.cc @@ -255,10 +255,18 @@ class ReferenceResolver : public AstRewriterBase { if (reference_map_ == nullptr) { return nullptr; } + auto iter = reference_map_->find(expr_id); if (iter == reference_map_->end()) { return nullptr; } + if (expr_id == 0) { + warnings_ + .AddWarning(absl::InvalidArgumentError( + "reference map entries for expression id 0 are not supported")) + .IgnoreError(); + return nullptr; + } return &iter->second; } diff --git a/eval/compiler/qualified_reference_resolver.h b/eval/compiler/qualified_reference_resolver.h index eb142031e..fd6c6199b 100644 --- a/eval/compiler/qualified_reference_resolver.h +++ b/eval/compiler/qualified_reference_resolver.h @@ -6,7 +6,6 @@ #include "google/api/expr/v1alpha1/checked.pb.h" #include "google/api/expr/v1alpha1/syntax.pb.h" #include "google/protobuf/map.h" -#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "eval/compiler/resolver.h" diff --git a/eval/compiler/qualified_reference_resolver_test.cc b/eval/compiler/qualified_reference_resolver_test.cc index 9ae1170dd..642622901 100644 --- a/eval/compiler/qualified_reference_resolver_test.cc +++ b/eval/compiler/qualified_reference_resolver_test.cc @@ -23,13 +23,14 @@ namespace { using ::google::api::expr::v1alpha1::Expr; using ::google::api::expr::v1alpha1::Reference; using ::google::api::expr::v1alpha1::SourceInfo; +using ::google::api::expr::testutil::EqualsProto; +using testing::Contains; using testing::ElementsAre; using testing::Eq; using testing::IsEmpty; using testing::UnorderedElementsAre; using cel::internal::IsOkAndHolds; using cel::internal::StatusIs; -using testutil::EqualsProto; // foo.bar.var1 && bar.foo.var2 constexpr char kExpr[] = R"( @@ -837,6 +838,49 @@ TEST(ResolveReferences, EnumConstReferenceUsedInComprehension) { })pb")); } +TEST(ResolveReferences, ReferenceToId0Warns) { + // ID 0 is unsupported since it is not normally used by parsers and is + // ambiguous as an intentional ID or default for unset field. + Expr expr = ParseTestProto(R"pb( + id: 0 + select_expr { + operand { + id: 1 + ident_expr { name: "pkg" } + } + field: "var" + })pb"); + + SourceInfo source_info; + + google::protobuf::Map reference_map; + CelFunctionRegistry func_registry; + ASSERT_OK(RegisterBuiltinFunctions(&func_registry)); + CelTypeRegistry type_registry; + Resolver registry("", &func_registry, &type_registry); + reference_map[0].set_name("pkg.var"); + BuilderWarnings warnings; + + auto result = ResolveReferences(&reference_map, registry, &source_info, + warnings, &expr); + + ASSERT_THAT(result, IsOkAndHolds(false)); + EXPECT_THAT(expr, EqualsProto(R"pb( + id: 0 + select_expr { + operand { + id: 1 + ident_expr { name: "pkg" } + } + field: "var" + })pb")); + + EXPECT_THAT( + warnings.warnings(), + Contains(StatusIs( + absl::StatusCode::kInvalidArgument, + "reference map entries for expression id 0 are not supported"))); +} } // namespace } // namespace google::api::expr::runtime diff --git a/eval/eval/evaluator_core.cc b/eval/eval/evaluator_core.cc index 27904ce45..1175f603f 100644 --- a/eval/eval/evaluator_core.cc +++ b/eval/eval/evaluator_core.cc @@ -40,7 +40,7 @@ const ExpressionStep* ExecutionFrame::Next() { if (pc_ < end_pos) return execution_path_[pc_++].get(); if (pc_ > end_pos) { - GOOGLE_LOG(ERROR) << "Attempting to step beyond the end of execution path."; + LOG(ERROR) << "Attempting to step beyond the end of execution path."; } return nullptr; } @@ -176,7 +176,7 @@ absl::StatusOr CelExpressionFlatImpl::Trace( } if (stack->empty()) { - GOOGLE_LOG(ERROR) << "Stack is empty after a ExpressionStep.Evaluate. " + LOG(ERROR) << "Stack is empty after a ExpressionStep.Evaluate. " "Try to disable short-circuiting."; continue; } diff --git a/eval/eval/evaluator_stack.h b/eval/eval/evaluator_stack.h index 331a999ec..8e67bfa39 100644 --- a/eval/eval/evaluator_stack.h +++ b/eval/eval/evaluator_stack.h @@ -42,7 +42,7 @@ class EvaluatorStack { // Please note that calls to Push may invalidate returned Span object. absl::Span GetSpan(size_t size) const { if (!HasEnough(size)) { - GOOGLE_LOG(ERROR) << "Requested span size (" << size + LOG(ERROR) << "Requested span size (" << size << ") exceeds current stack size: " << current_size_; } return absl::Span(stack_.data() + current_size_ - size, @@ -61,7 +61,7 @@ class EvaluatorStack { // Checking that stack is not empty is caller's responsibility. const CelValue& Peek() const { if (empty()) { - GOOGLE_LOG(ERROR) << "Peeking on empty EvaluatorStack"; + LOG(ERROR) << "Peeking on empty EvaluatorStack"; } return stack_[current_size_ - 1]; } @@ -70,7 +70,7 @@ class EvaluatorStack { // Checking that stack is not empty is caller's responsibility. const AttributeTrail& PeekAttribute() const { if (empty()) { - GOOGLE_LOG(ERROR) << "Peeking on empty EvaluatorStack"; + LOG(ERROR) << "Peeking on empty EvaluatorStack"; } return attribute_stack_[current_size_ - 1]; } @@ -79,7 +79,7 @@ class EvaluatorStack { // Checking that stack has enough elements is caller's responsibility. void Pop(size_t size) { if (!HasEnough(size)) { - GOOGLE_LOG(ERROR) << "Trying to pop more elements (" << size + LOG(ERROR) << "Trying to pop more elements (" << size << ") than the current stack size: " << current_size_; } current_size_ -= size; @@ -90,7 +90,7 @@ class EvaluatorStack { void Push(const CelValue& value, AttributeTrail attribute) { if (current_size_ >= stack_.size()) { - GOOGLE_LOG(ERROR) << "No room to push more elements on to EvaluatorStack"; + LOG(ERROR) << "No room to push more elements on to EvaluatorStack"; } stack_[current_size_] = value; attribute_stack_[current_size_] = attribute; @@ -107,7 +107,7 @@ class EvaluatorStack { // Checking that stack is not empty is caller's responsibility. void PopAndPush(const CelValue& value, AttributeTrail attribute) { if (empty()) { - GOOGLE_LOG(ERROR) << "Cannot PopAndPush on empty stack."; + LOG(ERROR) << "Cannot PopAndPush on empty stack."; } stack_[current_size_ - 1] = value; attribute_stack_[current_size_ - 1] = attribute; @@ -124,7 +124,7 @@ class EvaluatorStack { // Returns true if any values are successfully converted. bool CoerceNullValues(size_t size) { if (!HasEnough(size)) { - GOOGLE_LOG(ERROR) << "Trying to coerce more elements (" << size + LOG(ERROR) << "Trying to coerce more elements (" << size << ") than the current stack size: " << current_size_; } bool updated = false; diff --git a/eval/eval/select_step.cc b/eval/eval/select_step.cc index 0e4d300c7..e3be86bb8 100644 --- a/eval/eval/select_step.cc +++ b/eval/eval/select_step.cc @@ -92,7 +92,7 @@ absl::optional CheckForMarkedAttributes(const AttributeTrail& trail, } // Invariant broken (an invalid CEL Attribute shouldn't match anything). // Log and return a CelError. - GOOGLE_LOG(ERROR) + LOG(ERROR) << "Invalid attribute pattern matched select path: " << attribute_string.status().ToString(); // NOLINT: OSS compatibility return CreateErrorValue(frame->memory_manager(), attribute_string.status()); diff --git a/eval/public/ast_rewrite.cc b/eval/public/ast_rewrite.cc index f8264ef43..132fdaa72 100644 --- a/eval/public/ast_rewrite.cc +++ b/eval/public/ast_rewrite.cc @@ -192,7 +192,7 @@ struct PostVisitor { &position); break; default: - GOOGLE_LOG(ERROR) << "Unsupported Expr kind: " << expr->expr_kind_case(); + LOG(ERROR) << "Unsupported Expr kind: " << expr->expr_kind_case(); } visitor->PostVisitExpr(expr, &position); diff --git a/eval/public/ast_traverse.cc b/eval/public/ast_traverse.cc index 02494de3c..a20605ed8 100644 --- a/eval/public/ast_traverse.cc +++ b/eval/public/ast_traverse.cc @@ -184,7 +184,7 @@ struct PostVisitor { &position); break; default: - GOOGLE_LOG(ERROR) << "Unsupported Expr kind: " << expr->expr_kind_case(); + LOG(ERROR) << "Unsupported Expr kind: " << expr->expr_kind_case(); } visitor->PostVisitExpr(expr, &position); diff --git a/eval/public/builtin_func_registrar.cc b/eval/public/builtin_func_registrar.cc index 613522a4d..581d9e1fe 100644 --- a/eval/public/builtin_func_registrar.cc +++ b/eval/public/builtin_func_registrar.cc @@ -1486,7 +1486,7 @@ absl::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry, auto regex_matches = [max_size = options.regex_max_program_size]( Arena* arena, CelValue::StringHolder target, CelValue::StringHolder regex) -> CelValue { - RE2 re2(regex.value().data()); + RE2 re2(re2::StringPiece(regex.value().data(), regex.value().size())); if (max_size > 0 && re2.ProgramSize() > max_size) { return CreateErrorValue(arena, "exceeded RE2 max program size", absl::StatusCode::kInvalidArgument); diff --git a/eval/public/cel_expr_builder_factory.cc b/eval/public/cel_expr_builder_factory.cc index c862826c2..679d60e38 100644 --- a/eval/public/cel_expr_builder_factory.cc +++ b/eval/public/cel_expr_builder_factory.cc @@ -39,12 +39,12 @@ std::unique_ptr CreateCelExpressionBuilder( google::protobuf::MessageFactory* message_factory, const InterpreterOptions& options) { if (descriptor_pool == nullptr) { - GOOGLE_LOG(ERROR) << "Cannot pass nullptr as descriptor pool to " + LOG(ERROR) << "Cannot pass nullptr as descriptor pool to " "CreateCelExpressionBuilder"; return nullptr; } if (auto s = ValidateStandardMessageTypes(*descriptor_pool); !s.ok()) { - GOOGLE_LOG(WARNING) << "Failed to validate standard message types: " + LOG(WARNING) << "Failed to validate standard message types: " << s.ToString(); // NOLINT: OSS compatibility return nullptr; } diff --git a/eval/public/cel_value.h b/eval/public/cel_value.h index 727d9af44..173d59db7 100644 --- a/eval/public/cel_value.h +++ b/eval/public/cel_value.h @@ -477,7 +477,7 @@ class CelValue { // Crashes with a null pointer error. static void CrashNullPointer(Type type) ABSL_ATTRIBUTE_COLD { - GOOGLE_LOG(FATAL) << "Null pointer supplied for " << TypeName(type); // Crash ok + LOG(FATAL) << "Null pointer supplied for " << TypeName(type); // Crash ok } // Null pointer checker for pointer-based types. @@ -490,7 +490,7 @@ class CelValue { // Crashes with a type mismatch error. static void CrashTypeMismatch(Type requested_type, Type actual_type) ABSL_ATTRIBUTE_COLD { - GOOGLE_LOG(FATAL) << "Type mismatch" // Crash ok + LOG(FATAL) << "Type mismatch" // Crash ok << ": expected " << TypeName(requested_type) // Crash ok << ", encountered " << TypeName(actual_type); // Crash ok } diff --git a/eval/public/portable_cel_expr_builder_factory.cc b/eval/public/portable_cel_expr_builder_factory.cc index 025982ff9..e5eb6f608 100644 --- a/eval/public/portable_cel_expr_builder_factory.cc +++ b/eval/public/portable_cel_expr_builder_factory.cc @@ -30,7 +30,7 @@ std::unique_ptr CreatePortableExprBuilder( std::unique_ptr type_provider, const InterpreterOptions& options) { if (type_provider == nullptr) { - GOOGLE_LOG(ERROR) << "Cannot pass nullptr as type_provider to " + LOG(ERROR) << "Cannot pass nullptr as type_provider to " "CreatePortableExprBuilder"; return nullptr; } diff --git a/eval/public/structs/field_access_impl.cc b/eval/public/structs/field_access_impl.cc index 9f8faf7ba..788a47666 100644 --- a/eval/public/structs/field_access_impl.cc +++ b/eval/public/structs/field_access_impl.cc @@ -586,7 +586,7 @@ class ScalarFieldSetter : public FieldSetter { bool SetMessage(const Message* value) const { if (!value) { - GOOGLE_LOG(ERROR) << "Message is NULL"; + LOG(ERROR) << "Message is NULL"; return true; } From 5962c6483269ae08c0c3e18e865c90e800e1c0a1 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Fri, 22 Jul 2022 21:13:08 +0000 Subject: [PATCH 018/303] Update comprehension memory exhaustion vulnerability check to correctly count references in result step of nested comprehensions. Previously, references would be ignored in the result step if the outer accumulator was shadowed by an iter_var (even though that variable may not be in scope for the result step). PiperOrigin-RevId: 462701024 --- eval/compiler/BUILD | 1 + eval/compiler/flat_expr_builder.cc | 37 ++++++++++--- .../flat_expr_builder_comprehensions_test.cc | 54 +++++++++++++++++++ 3 files changed, 84 insertions(+), 8 deletions(-) diff --git a/eval/compiler/BUILD b/eval/compiler/BUILD index e7ee05866..772754345 100644 --- a/eval/compiler/BUILD +++ b/eval/compiler/BUILD @@ -104,6 +104,7 @@ cc_test( "//eval/public:cel_value", "//eval/public:unknown_attribute_set", "//eval/public:unknown_set", + "//eval/public/containers:container_backed_list_impl", "//eval/public/testing:matchers", "//eval/testutil:test_message_cc_proto", "//internal:status_macros", diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index e57fc959c..438c186ca 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -802,6 +802,14 @@ const Expr* CurrentValueDummy() { // writers, only by macro authors. However, a hand-rolled AST makes it possible // to misuse the accumulation variable. // +// Limitations: +// - This check only covers standard operators and functions. +// Extension functions may cause the same issue if they allocate an amount of +// memory that is dependent on the size of the inputs. +// +// - This check is not exhaustive. There may be ways to construct an AST to +// trigger exponential memory growth not captured by this check. +// // The algorithm for reference counting is as follows: // // * Calls - If the call is a concatenation operator, sum the number of places @@ -869,17 +877,30 @@ int ComprehensionAccumulationReferences(const Expr& expr, const auto& comprehension = expr.comprehension_expr(); absl::string_view accu_var = comprehension.accu_var(); absl::string_view iter_var = comprehension.iter_var(); - // Tne accumulation or iteration variable shadows the var_name and so will - // not manipulate the target var_name in a nested comprhension scope. - if (accu_var == var_name || iter_var == var_name) { - return 0; + + int result_references = 0; + int loop_step_references = 0; + + // The accumulation or iteration variable shadows the var_name and so will + // not manipulate the target var_name in a nested comprehension scope. + if (accu_var != var_name && iter_var != var_name) { + loop_step_references = ComprehensionAccumulationReferences( + comprehension.loop_step(), var_name); } + + // Accumulator variable (but not necessarily iter var) can shadow an + // outer accumulator variable in the result sub-expression. + if (accu_var != var_name) { + result_references = ComprehensionAccumulationReferences( + comprehension.result(), var_name); + } + // Count the number of times the accumulator var_name within the loop_step // or the nested comprehension result. - const Expr& loop_step = comprehension.loop_step(); - const Expr& result = comprehension.result(); - return std::max(ComprehensionAccumulationReferences(loop_step, var_name), - ComprehensionAccumulationReferences(result, var_name)); + // + // This doesn't cover cases where the inner accumulator accumulates the + // outer accumulator then is returned in the inner comprehension result. + return std::max(loop_step_references, result_references); } case Expr::kListExpr: { // Count the number of times the accumulator var_name appears within a diff --git a/eval/compiler/flat_expr_builder_comprehensions_test.cc b/eval/compiler/flat_expr_builder_comprehensions_test.cc index 52b1276ed..6111c9f84 100644 --- a/eval/compiler/flat_expr_builder_comprehensions_test.cc +++ b/eval/compiler/flat_expr_builder_comprehensions_test.cc @@ -14,8 +14,12 @@ * limitations under the License. */ +#include +#include + #include "google/api/expr/v1alpha1/syntax.pb.h" #include "google/protobuf/field_mask.pb.h" +#include "google/protobuf/arena.h" #include "google/protobuf/text_format.h" #include "absl/status/status.h" #include "absl/strings/str_split.h" @@ -28,6 +32,7 @@ #include "eval/public/cel_expression.h" #include "eval/public/cel_options.h" #include "eval/public/cel_value.h" +#include "eval/public/containers/container_backed_list_impl.h" #include "eval/public/testing/matchers.h" #include "eval/public/unknown_attribute_set.h" #include "eval/public/unknown_set.h" @@ -376,6 +381,55 @@ TEST(FlatExprBuilderComprehensionsTest, HasSubstr("memory exhaustion vulnerability"))); } +TEST(FlatExprBuilderComprehensionsTest, + ComprehensionWithNestedComprehensionLoopStepVulernabilityResult) { + CheckedExpr expr; + // The nested comprehension performs an unsafe concatenation on the parent + // accumulator. + google::protobuf::TextFormat::ParseFromString( + R"pb( + expr { + comprehension_expr { + iter_var: "outer_iter" + iter_range { ident_expr { name: "input_list" } } + accu_var: "outer_accu" + accu_init { ident_expr { name: "input_list" } } + loop_condition { + id: 3 + const_expr { bool_value: true } + } + loop_step { + comprehension_expr { + # the iter_var shadows the outer accumulator on the loop step + # but not the result step. + iter_var: "outer_accu" + iter_range { list_expr {} } + accu_var: "inner_accu" + accu_init { list_expr {} } + loop_condition { const_expr { bool_value: true } } + loop_step { list_expr {} } + result { + call_expr { + function: "_+_" + args { ident_expr { name: "outer_accu" } } + args { ident_expr { name: "outer_accu" } } + } + } + } + } + result { list_expr {} } + } + } + )pb", + &expr); + FlatExprBuilder builder; + builder.set_enable_comprehension_vulnerability_check(true); + ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); + EXPECT_THAT(builder.CreateExpression(&expr).status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("memory exhaustion vulnerability"))); +} + } // namespace } // namespace google::api::expr::runtime From a13c4a861ed095e0055c5cdae124b0d2638cd857 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Fri, 22 Jul 2022 21:14:21 +0000 Subject: [PATCH 019/303] Remove deprecated option for partial string match. PiperOrigin-RevId: 462701292 --- eval/public/cel_options.h | 3 --- 1 file changed, 3 deletions(-) diff --git a/eval/public/cel_options.h b/eval/public/cel_options.h index 1311e5cbe..3b6fba93d 100644 --- a/eval/public/cel_options.h +++ b/eval/public/cel_options.h @@ -61,9 +61,6 @@ struct InterpreterOptions { // resulting value is known from the left-hand side. bool short_circuiting = true; - // DEPRECATED. This option has no effect. - bool partial_string_match = true; - // Enable constant folding during the expression creation. If enabled, // an arena must be provided for constant generation. // Note that expression tracing applies a modified expression if this option From 8e5c9cc606c3a3696ebea7900083ac44da64478a Mon Sep 17 00:00:00 2001 From: jcking Date: Thu, 28 Jul 2022 22:53:08 +0000 Subject: [PATCH 020/303] Add skeletons for `UnknownType` and `UnknownValue` ahead of actual implementation PiperOrigin-RevId: 463946177 --- base/BUILD | 1 - base/kind.cc | 2 ++ base/kind.h | 1 + base/kind_test.cc | 1 + base/type.cc | 9 ++++++ base/type.h | 5 ++++ base/type_factory.cc | 4 +++ base/type_factory.h | 3 ++ base/type_test.cc | 44 +++++++++++++++++++++++++++-- base/types/unknown_type.cc | 41 +++++++++++++++++++++++++++ base/types/unknown_type.h | 54 +++++++++++++++++++++++++++++++++++ base/value.cc | 13 +++++++++ base/value.h | 2 ++ base/value_factory.cc | 6 ++++ base/value_factory.h | 4 +++ base/value_test.cc | 21 ++++++++++++++ base/values/unknown_value.cc | 43 ++++++++++++++++++++++++++++ base/values/unknown_value.h | 55 ++++++++++++++++++++++++++++++++++++ 18 files changed, 306 insertions(+), 3 deletions(-) create mode 100644 base/types/unknown_type.cc create mode 100644 base/types/unknown_type.h create mode 100644 base/values/unknown_value.cc create mode 100644 base/values/unknown_value.h diff --git a/base/BUILD b/base/BUILD index 9a4f65ba1..c012c81c3 100644 --- a/base/BUILD +++ b/base/BUILD @@ -204,7 +204,6 @@ cc_test( ":value", "//base/internal:memory_manager_testing", "//internal:testing", - "@com_google_absl//absl/hash", "@com_google_absl//absl/hash:hash_testing", "@com_google_absl//absl/status", ], diff --git a/base/kind.cc b/base/kind.cc index 8eccd110f..60b0fdad2 100644 --- a/base/kind.cc +++ b/base/kind.cc @@ -50,6 +50,8 @@ absl::string_view KindToString(Kind kind) { return "map"; case Kind::kStruct: return "struct"; + case Kind::kUnknown: + return "*unknown*"; default: return "*error*"; } diff --git a/base/kind.h b/base/kind.h index a32c432f8..66ecb2016 100644 --- a/base/kind.h +++ b/base/kind.h @@ -40,6 +40,7 @@ enum class Kind : uint8_t { kList, kMap, kStruct, + kUnknown, // INTERNAL: Do not exceed 127. Implementation details rely on the fact that // we can store `Kind` using 7 bits. diff --git a/base/kind_test.cc b/base/kind_test.cc index 22050cf9c..fbb40e866 100644 --- a/base/kind_test.cc +++ b/base/kind_test.cc @@ -39,6 +39,7 @@ TEST(Kind, ToString) { EXPECT_EQ(KindToString(Kind::kList), "list"); EXPECT_EQ(KindToString(Kind::kMap), "map"); EXPECT_EQ(KindToString(Kind::kStruct), "struct"); + EXPECT_EQ(KindToString(Kind::kUnknown), "*unknown*"); EXPECT_EQ(KindToString(static_cast(std::numeric_limits::max())), "*error*"); } diff --git a/base/type.cc b/base/type.cc index 164a30b84..9cc9cd71f 100644 --- a/base/type.cc +++ b/base/type.cc @@ -38,6 +38,7 @@ #include "base/types/timestamp_type.h" #include "base/types/type_type.h" #include "base/types/uint_type.h" +#include "base/types/unknown_type.h" #include "internal/unreachable.h" namespace cel { @@ -80,6 +81,8 @@ absl::string_view Type::name() const { return static_cast(this)->name(); case Kind::kStruct: return static_cast(this)->name(); + case Kind::kUnknown: + return static_cast(this)->name(); } } @@ -119,6 +122,8 @@ std::string Type::DebugString() const { return static_cast(this)->DebugString(); case Kind::kStruct: return static_cast(this)->DebugString(); + case Kind::kUnknown: + return static_cast(this)->DebugString(); } } @@ -161,6 +166,8 @@ bool Type::Equals(const Type& other) const { return static_cast(this)->Equals(other); case Kind::kStruct: return static_cast(this)->Equals(other); + case Kind::kUnknown: + return static_cast(this)->Equals(other); } } @@ -202,6 +209,8 @@ void Type::HashValue(absl::HashState state) const { return static_cast(this)->HashValue(std::move(state)); case Kind::kStruct: return static_cast(this)->HashValue(std::move(state)); + case Kind::kUnknown: + return static_cast(this)->HashValue(std::move(state)); } } diff --git a/base/type.h b/base/type.h index 23057143c..867c19e8d 100644 --- a/base/type.h +++ b/base/type.h @@ -261,6 +261,11 @@ struct SimpleTypeName { static constexpr absl::string_view value = "type"; }; +template <> +struct SimpleTypeName { + static constexpr absl::string_view value = "*unknown*"; +}; + template class SimpleType : public Type, public InlineData { public: diff --git a/base/type_factory.cc b/base/type_factory.cc index 2b257e517..66b1eb8f2 100644 --- a/base/type_factory.cc +++ b/base/type_factory.cc @@ -75,6 +75,10 @@ Persistent TypeFactory::GetTypeType() { return TypeType::Get(); } +Persistent TypeFactory::GetUnknownType() { + return UnknownType::Get(); +} + absl::StatusOr> TypeFactory::CreateListType( const Persistent& element) { absl::MutexLock lock(&list_types_mutex_); diff --git a/base/type_factory.h b/base/type_factory.h index b96f20b1f..933929bcb 100644 --- a/base/type_factory.h +++ b/base/type_factory.h @@ -41,6 +41,7 @@ #include "base/types/timestamp_type.h" #include "base/types/type_type.h" #include "base/types/uint_type.h" +#include "base/types/unknown_type.h" namespace cel { @@ -112,6 +113,8 @@ class TypeFactory final { Persistent GetTypeType() ABSL_ATTRIBUTE_LIFETIME_BOUND; + Persistent GetUnknownType() ABSL_ATTRIBUTE_LIFETIME_BOUND; + MemoryManager& memory_manager() const { return memory_manager_; } private: diff --git a/base/type_test.cc b/base/type_test.cc index e057bd299..d77b7720b 100644 --- a/base/type_test.cc +++ b/base/type_test.cc @@ -14,10 +14,10 @@ #include "base/type.h" +#include #include #include -#include "absl/hash/hash.h" #include "absl/hash/hash_testing.h" #include "absl/status/status.h" #include "base/handle.h" @@ -33,7 +33,6 @@ namespace cel { namespace { -using testing::SizeIs; using cel::internal::StatusIs; enum class TestEnum { @@ -260,6 +259,7 @@ TEST_P(TypeTest, Null) { EXPECT_FALSE(type_factory.GetNullType().Is()); EXPECT_FALSE(type_factory.GetNullType().Is()); EXPECT_FALSE(type_factory.GetNullType().Is()); + EXPECT_FALSE(type_factory.GetNullType().Is()); } TEST_P(TypeTest, Error) { @@ -282,6 +282,7 @@ TEST_P(TypeTest, Error) { EXPECT_FALSE(type_factory.GetErrorType().Is()); EXPECT_FALSE(type_factory.GetErrorType().Is()); EXPECT_FALSE(type_factory.GetErrorType().Is()); + EXPECT_FALSE(type_factory.GetErrorType().Is()); } TEST_P(TypeTest, Dyn) { @@ -304,6 +305,7 @@ TEST_P(TypeTest, Dyn) { EXPECT_FALSE(type_factory.GetDynType().Is()); EXPECT_FALSE(type_factory.GetDynType().Is()); EXPECT_FALSE(type_factory.GetDynType().Is()); + EXPECT_FALSE(type_factory.GetDynType().Is()); } TEST_P(TypeTest, Any) { @@ -326,6 +328,7 @@ TEST_P(TypeTest, Any) { EXPECT_FALSE(type_factory.GetAnyType().Is()); EXPECT_FALSE(type_factory.GetAnyType().Is()); EXPECT_FALSE(type_factory.GetAnyType().Is()); + EXPECT_FALSE(type_factory.GetAnyType().Is()); } TEST_P(TypeTest, Bool) { @@ -348,6 +351,7 @@ TEST_P(TypeTest, Bool) { EXPECT_FALSE(type_factory.GetBoolType().Is()); EXPECT_FALSE(type_factory.GetBoolType().Is()); EXPECT_FALSE(type_factory.GetBoolType().Is()); + EXPECT_FALSE(type_factory.GetBoolType().Is()); } TEST_P(TypeTest, Int) { @@ -370,6 +374,7 @@ TEST_P(TypeTest, Int) { EXPECT_FALSE(type_factory.GetIntType().Is()); EXPECT_FALSE(type_factory.GetIntType().Is()); EXPECT_FALSE(type_factory.GetIntType().Is()); + EXPECT_FALSE(type_factory.GetIntType().Is()); } TEST_P(TypeTest, Uint) { @@ -392,6 +397,7 @@ TEST_P(TypeTest, Uint) { EXPECT_FALSE(type_factory.GetUintType().Is()); EXPECT_FALSE(type_factory.GetUintType().Is()); EXPECT_FALSE(type_factory.GetUintType().Is()); + EXPECT_FALSE(type_factory.GetUintType().Is()); } TEST_P(TypeTest, Double) { @@ -414,6 +420,7 @@ TEST_P(TypeTest, Double) { EXPECT_FALSE(type_factory.GetDoubleType().Is()); EXPECT_FALSE(type_factory.GetDoubleType().Is()); EXPECT_FALSE(type_factory.GetDoubleType().Is()); + EXPECT_FALSE(type_factory.GetDoubleType().Is()); } TEST_P(TypeTest, String) { @@ -436,6 +443,7 @@ TEST_P(TypeTest, String) { EXPECT_FALSE(type_factory.GetStringType().Is()); EXPECT_FALSE(type_factory.GetStringType().Is()); EXPECT_FALSE(type_factory.GetStringType().Is()); + EXPECT_FALSE(type_factory.GetStringType().Is()); } TEST_P(TypeTest, Bytes) { @@ -458,6 +466,7 @@ TEST_P(TypeTest, Bytes) { EXPECT_FALSE(type_factory.GetBytesType().Is()); EXPECT_FALSE(type_factory.GetBytesType().Is()); EXPECT_FALSE(type_factory.GetBytesType().Is()); + EXPECT_FALSE(type_factory.GetBytesType().Is()); } TEST_P(TypeTest, Duration) { @@ -480,6 +489,7 @@ TEST_P(TypeTest, Duration) { EXPECT_FALSE(type_factory.GetDurationType().Is()); EXPECT_FALSE(type_factory.GetDurationType().Is()); EXPECT_FALSE(type_factory.GetDurationType().Is()); + EXPECT_FALSE(type_factory.GetDurationType().Is()); } TEST_P(TypeTest, Timestamp) { @@ -503,6 +513,7 @@ TEST_P(TypeTest, Timestamp) { EXPECT_FALSE(type_factory.GetTimestampType().Is()); EXPECT_FALSE(type_factory.GetTimestampType().Is()); EXPECT_FALSE(type_factory.GetTimestampType().Is()); + EXPECT_FALSE(type_factory.GetTimestampType().Is()); } TEST_P(TypeTest, Enum) { @@ -528,6 +539,7 @@ TEST_P(TypeTest, Enum) { EXPECT_FALSE(enum_type.Is()); EXPECT_FALSE(enum_type.Is()); EXPECT_FALSE(enum_type.Is()); + EXPECT_FALSE(enum_type.Is()); } TEST_P(TypeTest, Struct) { @@ -555,6 +567,7 @@ TEST_P(TypeTest, Struct) { EXPECT_FALSE(struct_type.Is()); EXPECT_FALSE(struct_type.Is()); EXPECT_FALSE(struct_type.Is()); + EXPECT_FALSE(struct_type.Is()); } TEST_P(TypeTest, List) { @@ -582,6 +595,7 @@ TEST_P(TypeTest, List) { EXPECT_TRUE(list_type.Is()); EXPECT_FALSE(list_type.Is()); EXPECT_FALSE(list_type.Is()); + EXPECT_FALSE(list_type.Is()); } TEST_P(TypeTest, Map) { @@ -615,6 +629,7 @@ TEST_P(TypeTest, Map) { EXPECT_FALSE(map_type.Is()); EXPECT_TRUE(map_type.Is()); EXPECT_FALSE(map_type.Is()); + EXPECT_FALSE(map_type.Is()); } TEST_P(TypeTest, TypeType) { @@ -637,6 +652,30 @@ TEST_P(TypeTest, TypeType) { EXPECT_FALSE(type_factory.GetTypeType().Is()); EXPECT_FALSE(type_factory.GetTypeType().Is()); EXPECT_TRUE(type_factory.GetTypeType().Is()); + EXPECT_FALSE(type_factory.GetTypeType().Is()); +} + +TEST_P(TypeTest, UnknownType) { + TypeFactory type_factory(memory_manager()); + EXPECT_EQ(type_factory.GetUnknownType()->kind(), Kind::kUnknown); + EXPECT_EQ(type_factory.GetUnknownType()->name(), "*unknown*"); + EXPECT_FALSE(type_factory.GetUnknownType().Is()); + EXPECT_FALSE(type_factory.GetUnknownType().Is()); + EXPECT_FALSE(type_factory.GetUnknownType().Is()); + EXPECT_FALSE(type_factory.GetUnknownType().Is()); + EXPECT_FALSE(type_factory.GetUnknownType().Is()); + EXPECT_FALSE(type_factory.GetUnknownType().Is()); + EXPECT_FALSE(type_factory.GetUnknownType().Is()); + EXPECT_FALSE(type_factory.GetUnknownType().Is()); + EXPECT_FALSE(type_factory.GetUnknownType().Is()); + EXPECT_FALSE(type_factory.GetUnknownType().Is()); + EXPECT_FALSE(type_factory.GetUnknownType().Is()); + EXPECT_FALSE(type_factory.GetUnknownType().Is()); + EXPECT_FALSE(type_factory.GetUnknownType().Is()); + EXPECT_FALSE(type_factory.GetUnknownType().Is()); + EXPECT_FALSE(type_factory.GetUnknownType().Is()); + EXPECT_FALSE(type_factory.GetUnknownType().Is()); + EXPECT_TRUE(type_factory.GetUnknownType().Is()); } using EnumTypeTest = TypeTest; @@ -875,6 +914,7 @@ TEST_P(TypeTest, SupportsAbslHash) { Persistent(Must(type_factory.CreateMapType( type_factory.GetStringType(), type_factory.GetBoolType()))), Persistent(type_factory.GetTypeType()), + Persistent(type_factory.GetUnknownType()), })); } diff --git a/base/types/unknown_type.cc b/base/types/unknown_type.cc new file mode 100644 index 000000000..baec35e42 --- /dev/null +++ b/base/types/unknown_type.cc @@ -0,0 +1,41 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/types/unknown_type.h" + +#include "absl/base/attributes.h" +#include "absl/base/call_once.h" + +namespace cel { + +CEL_INTERNAL_TYPE_IMPL(UnknownType); + +namespace { + +ABSL_CONST_INIT absl::once_flag instance_once; +alignas(Persistent) char instance_storage[sizeof( + Persistent)]; + +} // namespace + +const Persistent& UnknownType::Get() { + absl::call_once(instance_once, []() { + base_internal::PersistentHandleFactory::MakeAt< + UnknownType>(&instance_storage[0]); + }); + return *reinterpret_cast*>( + &instance_storage[0]); +} + +} // namespace cel diff --git a/base/types/unknown_type.h b/base/types/unknown_type.h new file mode 100644 index 000000000..9979a89f2 --- /dev/null +++ b/base/types/unknown_type.h @@ -0,0 +1,54 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_TYPES_UNKNOWN_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_BASE_TYPES_UNKNOWN_TYPE_H_ + +#include "base/kind.h" +#include "base/type.h" + +namespace cel { + +class UnknownValue; + +class UnknownType final : public base_internal::SimpleType { + private: + using Base = base_internal::SimpleType; + + public: + using Base::kKind; + + using Base::kName; + + using Base::Is; + + using Base::kind; + + using Base::name; + + using Base::DebugString; + + using Base::HashValue; + + using Base::Equals; + + private: + CEL_INTERNAL_SIMPLE_TYPE_MEMBERS(UnknownType, UnknownValue); +}; + +CEL_INTERNAL_SIMPLE_TYPE_STANDALONES(UnknownType); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_TYPES_UNKNOWN_TYPE_H_ diff --git a/base/value.cc b/base/value.cc index 593a51c3d..4a29fbdfe 100644 --- a/base/value.cc +++ b/base/value.cc @@ -34,6 +34,7 @@ #include "base/values/timestamp_value.h" #include "base/values/type_value.h" #include "base/values/uint_value.h" +#include "base/values/unknown_value.h" #include "internal/unreachable.h" namespace cel { @@ -72,6 +73,8 @@ const Persistent& Value::type() const { return static_cast(this)->type().As(); case Kind::kStruct: return static_cast(this)->type().As(); + case Kind::kUnknown: + return static_cast(this)->type().As(); default: internal::unreachable(); } @@ -109,6 +112,8 @@ std::string Value::DebugString() const { return static_cast(this)->DebugString(); case Kind::kStruct: return static_cast(this)->DebugString(); + case Kind::kUnknown: + return static_cast(this)->DebugString(); default: internal::unreachable(); } @@ -148,6 +153,9 @@ void Value::HashValue(absl::HashState state) const { return static_cast(this)->HashValue(std::move(state)); case Kind::kStruct: return static_cast(this)->HashValue(std::move(state)); + case Kind::kUnknown: + return static_cast(this)->HashValue( + std::move(state)); default: internal::unreachable(); } @@ -188,6 +196,8 @@ bool Value::Equals(const Value& other) const { return static_cast(this)->Equals(other); case Kind::kStruct: return static_cast(this)->Equals(other); + case Kind::kUnknown: + return static_cast(this)->Equals(other); default: internal::unreachable(); } @@ -353,6 +363,9 @@ void PersistentValueHandle::Delete() const { case Kind::kBytes: delete static_cast(static_cast(data_.get())); break; + case Kind::kUnknown: + delete static_cast(static_cast(data_.get())); + break; default: internal::unreachable(); } diff --git a/base/value.h b/base/value.h index 35154f514..bf911fc36 100644 --- a/base/value.h +++ b/base/value.h @@ -45,6 +45,7 @@ class StructValue; class ListValue; class MapValue; class TypeValue; +class UnknownValue; class ValueFactory; // A representation of a CEL value that enables reflection and introspection of @@ -76,6 +77,7 @@ class Value : public base_internal::Data { friend class ListValue; friend class MapValue; friend class TypeValue; + friend class UnknownValue; friend class base_internal::PersistentValueHandle; template friend class base_internal::SimpleValue; diff --git a/base/value_factory.cc b/base/value_factory.cc index 4f5025189..b17b1d534 100644 --- a/base/value_factory.cc +++ b/base/value_factory.cc @@ -14,6 +14,7 @@ #include "base/value_factory.h" +#include #include #include @@ -244,6 +245,11 @@ Persistent ValueFactory::CreateTypeValue( return PersistentHandleFactory::Make(value); } +Persistent ValueFactory::CreateUnknownValue() { + return PersistentHandleFactory::Make( + memory_manager()); +} + absl::StatusOr> ValueFactory::CreateBytesValueFromView(absl::string_view value) { return PersistentHandleFactory::Make< diff --git a/base/value_factory.h b/base/value_factory.h index 7e1c6fc86..cfb2829a3 100644 --- a/base/value_factory.h +++ b/base/value_factory.h @@ -46,6 +46,7 @@ #include "base/values/timestamp_value.h" #include "base/values/type_value.h" #include "base/values/uint_value.h" +#include "base/values/unknown_value.h" namespace cel { @@ -207,6 +208,9 @@ class ValueFactory final { Persistent CreateTypeValue( const Persistent& value) ABSL_ATTRIBUTE_LIFETIME_BOUND; + Persistent CreateUnknownValue() + ABSL_ATTRIBUTE_LIFETIME_BOUND; + MemoryManager& memory_manager() const { return type_manager().memory_manager(); } diff --git a/base/value_test.cc b/base/value_test.cc index 89846faea..5e43251fe 100644 --- a/base/value_test.cc +++ b/base/value_test.cc @@ -19,9 +19,11 @@ #include #include #include +#include #include #include #include +#include #include "absl/hash/hash.h" #include "absl/hash/hash_testing.h" @@ -654,6 +656,11 @@ INSTANTIATE_TEST_SUITE_P( ValueFactory& value_factory) -> Persistent { return value_factory.CreateTypeValue(type_factory.GetNullType()); }}, + {"Unknown", + [](TypeFactory& type_factory, + ValueFactory& value_factory) -> Persistent { + return value_factory.CreateUnknownValue(); + }}, })), [](const testing::TestParamInfo< std::tuple()); + EXPECT_FALSE(zero_value.Is()); + EXPECT_EQ(zero_value, zero_value); + EXPECT_EQ(zero_value, value_factory.CreateUnknownValue()); + EXPECT_EQ(zero_value->kind(), Kind::kUnknown); + EXPECT_EQ(zero_value->type(), type_factory.GetUnknownType()); +} + Persistent MakeStringBytes(ValueFactory& value_factory, absl::string_view value) { return Must(value_factory.CreateBytesValue(value)); @@ -2463,6 +2483,7 @@ TEST_P(ValueTest, SupportsAbslHash) { Persistent(map_value), Persistent( value_factory.CreateTypeValue(type_factory.GetNullType())), + Persistent(value_factory.CreateUnknownValue()), })); } diff --git a/base/values/unknown_value.cc b/base/values/unknown_value.cc new file mode 100644 index 000000000..2aab7b4ab --- /dev/null +++ b/base/values/unknown_value.cc @@ -0,0 +1,43 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/values/unknown_value.h" + +#include +#include + +#include "absl/base/macros.h" + +namespace cel { + +CEL_INTERNAL_VALUE_IMPL(UnknownValue); + +UnknownValue::UnknownValue() : base_internal::HeapData(kKind) { + // Ensure `Value*` and `base_internal::HeapData*` are not thunked. + ABSL_ASSERT( + reinterpret_cast(static_cast(this)) == + reinterpret_cast(static_cast(this))); +} + +std::string UnknownValue::DebugString() const { return "*unknown*"; } + +void UnknownValue::HashValue(absl::HashState state) const { + absl::HashState::combine(std::move(state), type()); +} + +bool UnknownValue::Equals(const Value& other) const { + return kind() == other.kind(); +} + +} // namespace cel diff --git a/base/values/unknown_value.h b/base/values/unknown_value.h new file mode 100644 index 000000000..a14c897b5 --- /dev/null +++ b/base/values/unknown_value.h @@ -0,0 +1,55 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_VALUES_UNKNOWN_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_BASE_VALUES_UNKNOWN_VALUE_H_ + +#include + +#include "absl/hash/hash.h" +#include "base/types/unknown_type.h" +#include "base/value.h" + +namespace cel { + +class UnknownValue final : public Value, public base_internal::HeapData { + public: + static constexpr Kind kKind = UnknownType::kKind; + + static bool Is(const Value& value) { return value.kind() == kKind; } + + constexpr Kind kind() const { return kKind; } + + const Persistent& type() const { + return UnknownType::Get(); + } + + std::string DebugString() const; + + void HashValue(absl::HashState state) const; + + bool Equals(const Value& other) const; + + private: + friend class cel::MemoryManager; + friend class ValueFactory; + + UnknownValue(); +}; + +CEL_INTERNAL_VALUE_DECL(UnknownValue); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_VALUES_UNKNOWN_VALUE_H_ From 8f06a0ec784c808237fdd00d401d2abd33abf474 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Tue, 2 Aug 2022 21:01:21 +0000 Subject: [PATCH 021/303] internal testing change PiperOrigin-RevId: 464880898 --- conformance/BUILD | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/conformance/BUILD b/conformance/BUILD index 9c2408c83..adc40627f 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -82,14 +82,17 @@ cc_binary( # TODO(issues/112): Unbound functions result in empty eval response. "--skip_test=basic/functions/unbound", "--skip_test=basic/functions/unbound_is_runtime_error", - # TODO(issues/116): Debug why dynamic/list/var fails to JSON parse correctly. + # TODO(issues/116): dynamic/list/var and dynamic/value_struct/var fail to JSON parse correctly. "--skip_test=dynamic/list/var", + "--skip_test=dynamic/value_struct/var", # TODO(issues/97): Parse-only qualified variable lookup "x.y" wtih binding "x.y" or "y" within container "x" fails "--skip_test=fields/qualified_identifier_resolution/qualified_ident,map_field_select,ident_with_longest_prefix_check,qualified_identifier_resolution_unchecked", "--skip_test=namespace/qualified/self_eval_qualified_lookup", "--skip_test=namespace/namespace/self_eval_container_lookup,self_eval_container_lookup_unchecked", # TODO(issues/117): Integer overflow on enum assignments should error. "--skip_test=enums/legacy_proto2/select_big,select_neg", + # TODO(issues/5): JSON to Proto parser/unparser doesn't support unknown enum values. + "--skip_test=enums/legacy_proto2/assign_standalone_int_big,assign_standalone_int_neg", # Future features for CEL 1.0 # TODO(issues/119): Strong typing support for enums, specified but not implemented. From 1d21fdcbd1cb0c92c7b66a39ddc43f89cb3afe16 Mon Sep 17 00:00:00 2001 From: jcking Date: Fri, 5 Aug 2022 20:43:17 +0000 Subject: [PATCH 022/303] Prepare `UnknownAttributeSet`, `UnknownFunctionResultSet`, and friends for interoperation with the new API PiperOrigin-RevId: 465638656 --- ...ilder_short_circuiting_conformance_test.cc | 28 ++-- eval/compiler/flat_expr_builder_test.cc | 13 +- eval/eval/BUILD | 5 +- eval/eval/attribute_trail.cc | 28 ++-- eval/eval/attribute_trail.h | 16 +- eval/eval/attribute_trail_test.cc | 3 +- eval/eval/attribute_utility.cc | 53 ++++--- eval/eval/attribute_utility.h | 13 +- eval/eval/attribute_utility_test.cc | 25 ++-- eval/eval/comprehension_step.cc | 4 +- eval/eval/comprehension_step_test.cc | 10 +- eval/eval/create_list_step_test.cc | 8 +- eval/eval/evaluator_core_test.cc | 4 +- eval/eval/evaluator_stack.h | 14 +- eval/eval/evaluator_stack_test.cc | 9 +- eval/eval/function_step.cc | 2 +- eval/eval/ident_step.cc | 7 +- eval/eval/logic_step_test.cc | 24 ++- eval/eval/select_step.cc | 14 +- eval/eval/ternary_step_test.cc | 13 +- eval/public/BUILD | 6 +- eval/public/activation_test.cc | 4 +- eval/public/cel_attribute.cc | 139 +++++++++++++++++- eval/public/cel_attribute.h | 40 ++++- eval/public/cel_function.cc | 67 ++++++++- eval/public/cel_function.h | 38 +++-- eval/public/unknown_attribute_set.h | 71 ++++++--- eval/public/unknown_attribute_set_test.cc | 26 ++-- eval/public/unknown_function_result_set.cc | 68 +-------- eval/public/unknown_function_result_set.h | 96 +++++++++--- .../unknown_function_result_set_test.cc | 22 +-- eval/public/unknown_set.h | 17 +++ eval/public/unknown_set_test.cc | 26 ++-- eval/tests/unknowns_end_to_end_test.cc | 50 +++---- 34 files changed, 616 insertions(+), 347 deletions(-) diff --git a/eval/compiler/flat_expr_builder_short_circuiting_conformance_test.cc b/eval/compiler/flat_expr_builder_short_circuiting_conformance_test.cc index 708f1cb2a..5f75bab81 100644 --- a/eval/compiler/flat_expr_builder_short_circuiting_conformance_test.cc +++ b/eval/compiler/flat_expr_builder_short_circuiting_conformance_test.cc @@ -251,8 +251,8 @@ TEST_P(ShortCircuitingTest, UnknownAnd) { ASSERT_TRUE(result.IsUnknownSet()); const UnknownAttributeSet& attrs = result.UnknownSetOrDie()->unknown_attributes(); - ASSERT_THAT(attrs.attributes(), testing::SizeIs(1)); - EXPECT_THAT(attrs.attributes()[0]->variable_name(), testing::Eq("var1")); + ASSERT_THAT(attrs, testing::SizeIs(1)); + EXPECT_THAT(attrs.begin()->variable_name(), testing::Eq("var1")); } TEST_P(ShortCircuitingTest, UnknownOr) { @@ -283,8 +283,8 @@ TEST_P(ShortCircuitingTest, UnknownOr) { ASSERT_TRUE(result.IsUnknownSet()); const UnknownAttributeSet& attrs = result.UnknownSetOrDie()->unknown_attributes(); - ASSERT_THAT(attrs.attributes(), testing::SizeIs(1)); - EXPECT_THAT(attrs.attributes()[0]->variable_name(), testing::Eq("var1")); + ASSERT_THAT(attrs, testing::SizeIs(1)); + EXPECT_THAT(attrs.begin()->variable_name(), testing::Eq("var1")); } TEST_P(ShortCircuitingTest, BasicTernary) { @@ -364,10 +364,9 @@ TEST_P(ShortCircuitingTest, TernaryUnknownCondHandling) { BuildAndEval(builder.get(), expr, activation, &arena, &result)); ASSERT_TRUE(result.IsUnknownSet()); - const auto& attrs = - result.UnknownSetOrDie()->unknown_attributes().attributes(); + const auto& attrs = result.UnknownSetOrDie()->unknown_attributes(); ASSERT_THAT(attrs, SizeIs(1)); - EXPECT_THAT(attrs[0]->variable_name(), Eq("cond")); + EXPECT_THAT(attrs.begin()->variable_name(), Eq("cond")); // Unknown branches are discarded if condition is unknown activation.set_unknown_attribute_patterns({CelAttributePattern("cond", {}), @@ -377,10 +376,9 @@ TEST_P(ShortCircuitingTest, TernaryUnknownCondHandling) { ASSERT_NO_FATAL_FAILURE( BuildAndEval(builder.get(), expr, activation, &arena, &result)); ASSERT_TRUE(result.IsUnknownSet()); - const auto& attrs2 = - result.UnknownSetOrDie()->unknown_attributes().attributes(); + const auto& attrs2 = result.UnknownSetOrDie()->unknown_attributes(); ASSERT_THAT(attrs2, SizeIs(1)); - EXPECT_THAT(attrs2[0]->variable_name(), Eq("cond")); + EXPECT_THAT(attrs2.begin()->variable_name(), Eq("cond")); } TEST_P(ShortCircuitingTest, TernaryUnknownArgsHandling) { @@ -413,10 +411,9 @@ TEST_P(ShortCircuitingTest, TernaryUnknownArgsHandling) { ASSERT_NO_FATAL_FAILURE( BuildAndEval(builder.get(), expr, activation, &arena, &result)); ASSERT_TRUE(result.IsUnknownSet()); - const auto& attrs3 = - result.UnknownSetOrDie()->unknown_attributes().attributes(); + const auto& attrs3 = result.UnknownSetOrDie()->unknown_attributes(); ASSERT_THAT(attrs3, SizeIs(1)); - EXPECT_EQ(attrs3[0]->variable_name(), "arg2"); + EXPECT_EQ(attrs3.begin()->variable_name(), "arg2"); } TEST_P(ShortCircuitingTest, TernaryUnknownAndErrorHandling) { @@ -451,10 +448,9 @@ TEST_P(ShortCircuitingTest, TernaryUnknownAndErrorHandling) { ASSERT_NO_FATAL_FAILURE( BuildAndEval(builder.get(), expr, activation, &arena, &result)); ASSERT_TRUE(result.IsUnknownSet()); - const auto& attrs = - result.UnknownSetOrDie()->unknown_attributes().attributes(); + const auto& attrs = result.UnknownSetOrDie()->unknown_attributes(); ASSERT_THAT(attrs, SizeIs(1)); - EXPECT_EQ(attrs[0]->variable_name(), "cond"); + EXPECT_EQ(attrs.begin()->variable_name(), "cond"); } const char* TestName(testing::TestParamInfo info) { diff --git a/eval/compiler/flat_expr_builder_test.cc b/eval/compiler/flat_expr_builder_test.cc index 404005195..ac651d889 100644 --- a/eval/compiler/flat_expr_builder_test.cc +++ b/eval/compiler/flat_expr_builder_test.cc @@ -1695,9 +1695,9 @@ TEST(FlatExprBuilderTest, Ternary) { value2.mutable_ident_expr()->set_name("value2"); CelAttribute value2_attr(value2, {}); - UnknownSet unknown_selector(UnknownAttributeSet({&selector_attr})); - UnknownSet unknown_value1(UnknownAttributeSet({&value1_attr})); - UnknownSet unknown_value2(UnknownAttributeSet({&value2_attr})); + UnknownSet unknown_selector(UnknownAttributeSet({selector_attr})); + UnknownSet unknown_value1(UnknownAttributeSet({value1_attr})); + UnknownSet unknown_value2(UnknownAttributeSet({value2_attr})); CelValue result; ASSERT_OK(RunTernaryExpression( CelValue::CreateUnknownSet(&unknown_selector), @@ -1705,10 +1705,9 @@ TEST(FlatExprBuilderTest, Ternary) { CelValue::CreateUnknownSet(&unknown_value2), &arena, &result)); ASSERT_TRUE(result.IsUnknownSet()); const UnknownSet* result_set = result.UnknownSetOrDie(); - EXPECT_THAT(result_set->unknown_attributes().attributes().size(), Eq(1)); - EXPECT_THAT( - result_set->unknown_attributes().attributes()[0]->variable_name(), - Eq("selector")); + EXPECT_THAT(result_set->unknown_attributes().size(), Eq(1)); + EXPECT_THAT(result_set->unknown_attributes().begin()->variable_name(), + Eq("selector")); } } diff --git a/eval/eval/BUILD b/eval/eval/BUILD index 1f4719042..ca4aabb15 100644 --- a/eval/eval/BUILD +++ b/eval/eval/BUILD @@ -183,7 +183,6 @@ cc_library( "//eval/public:cel_value", "//eval/public/structs:legacy_type_adapter", "//eval/public/structs:legacy_type_info_apis", - "//extensions/protobuf:memory_manager", "//internal:status_macros", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", @@ -577,8 +576,10 @@ cc_library( "//eval/public:cel_expression", "//eval/public:cel_value", "//eval/public:unknown_attribute_set", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/status", "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/utility", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], @@ -613,7 +614,6 @@ cc_library( "//eval/public:unknown_attribute_set", "//eval/public:unknown_function_result_set", "//eval/public:unknown_set", - "@com_google_absl//absl/status", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@com_google_protobuf//:protobuf", @@ -628,7 +628,6 @@ cc_test( ], deps = [ ":attribute_utility", - "//base:memory_manager", "//eval/public:cel_attribute", "//eval/public:cel_value", "//eval/public:unknown_attribute_set", diff --git a/eval/eval/attribute_trail.cc b/eval/eval/attribute_trail.cc index 6cbe0069a..c8023eacc 100644 --- a/eval/eval/attribute_trail.cc +++ b/eval/eval/attribute_trail.cc @@ -1,7 +1,11 @@ #include "eval/eval/attribute_trail.h" +#include +#include #include +#include +#include "absl/base/attributes.h" #include "absl/status/status.h" #include "eval/public/cel_attribute.h" #include "eval/public/cel_value.h" @@ -9,24 +13,26 @@ namespace google::api::expr::runtime { AttributeTrail::AttributeTrail(google::api::expr::v1alpha1::Expr root, - cel::MemoryManager& manager) { - attribute_ = manager - .New(std::move(root), - std::vector()) - .release(); + cel::MemoryManager& manager + ABSL_ATTRIBUTE_UNUSED) { + attribute_.emplace(std::move(root), std::vector()); } // Creates AttributeTrail with attribute path incremented by "qualifier". AttributeTrail AttributeTrail::Step(CelAttributeQualifier qualifier, - cel::MemoryManager& manager) const { + cel::MemoryManager& manager + ABSL_ATTRIBUTE_UNUSED) const { // Cannot continue void trail if (empty()) return AttributeTrail(); - std::vector qualifiers = attribute_->qualifier_path(); - qualifiers.push_back(qualifier); - auto attribute = manager.New( - std::string(attribute_->variable_name()), std::move(qualifiers)); - return AttributeTrail(attribute.release()); + std::vector qualifiers; + qualifiers.reserve(attribute_->qualifier_path().size() + 1); + std::copy_n(attribute_->qualifier_path().begin(), + attribute_->qualifier_path().size(), + std::back_inserter(qualifiers)); + qualifiers.push_back(std::move(qualifier)); + return AttributeTrail(CelAttribute(std::string(attribute_->variable_name()), + std::move(qualifiers))); } } // namespace google::api::expr::runtime diff --git a/eval/eval/attribute_trail.h b/eval/eval/attribute_trail.h index 875979e06..c537cdf76 100644 --- a/eval/eval/attribute_trail.h +++ b/eval/eval/attribute_trail.h @@ -8,6 +8,7 @@ #include "google/api/expr/v1alpha1/syntax.pb.h" #include "google/protobuf/arena.h" #include "absl/types/optional.h" +#include "absl/utility/utility.h" #include "base/memory_manager.h" #include "eval/public/cel_attribute.h" #include "eval/public/cel_expression.h" @@ -27,10 +28,13 @@ namespace google::api::expr::runtime { // or supported. class AttributeTrail { public: - AttributeTrail() : attribute_(nullptr) {} + AttributeTrail() = default; AttributeTrail(google::api::expr::v1alpha1::Expr root, cel::MemoryManager& manager); + explicit AttributeTrail(std::string variable_name) + : attribute_(absl::in_place, std::move(variable_name)) {} + // Creates AttributeTrail with attribute path incremented by "qualifier". AttributeTrail Step(CelAttributeQualifier qualifier, cel::MemoryManager& manager) const; @@ -44,14 +48,14 @@ class AttributeTrail { } // Returns CelAttribute that corresponds to content of AttributeTrail. - const CelAttribute* attribute() const { return attribute_; } + const CelAttribute& attribute() const { return attribute_.value(); } - bool empty() const { return attribute_ == nullptr; } + bool empty() const { return !attribute_.has_value(); } private: - explicit AttributeTrail(const CelAttribute* attribute) - : attribute_(attribute) {} - const CelAttribute* attribute_; + explicit AttributeTrail(CelAttribute attribute) + : attribute_(std::move(attribute)) {} + absl::optional attribute_; }; } // namespace google::api::expr::runtime diff --git a/eval/eval/attribute_trail_test.cc b/eval/eval/attribute_trail_test.cc index adb982860..ba0b2fcaf 100644 --- a/eval/eval/attribute_trail_test.cc +++ b/eval/eval/attribute_trail_test.cc @@ -36,8 +36,7 @@ TEST(AttributeTrailTest, AttributeTrailStep) { root.mutable_ident_expr()->set_name("ident"); AttributeTrail trail = AttributeTrail(root, manager).Step(&step, manager); - ASSERT_TRUE(trail.attribute() != nullptr); - ASSERT_EQ(*trail.attribute(), + ASSERT_EQ(trail.attribute(), CelAttribute(root, {CelAttributeQualifier::Create(step_value)})); } diff --git a/eval/eval/attribute_utility.cc b/eval/eval/attribute_utility.cc index 69e7813e0..95de45708 100644 --- a/eval/eval/attribute_utility.cc +++ b/eval/eval/attribute_utility.cc @@ -1,14 +1,13 @@ #include "eval/eval/attribute_utility.h" -#include "absl/status/status.h" +#include + #include "eval/public/cel_value.h" #include "eval/public/unknown_attribute_set.h" #include "eval/public/unknown_set.h" namespace google::api::expr::runtime { -using ::google::protobuf::Arena; - bool AttributeUtility::CheckForMissingAttribute( const AttributeTrail& trail) const { if (trail.empty()) { @@ -19,7 +18,7 @@ bool AttributeUtility::CheckForMissingAttribute( // (b/161297249) Preserving existing behavior for now, will add a streamz // for partial match, follow up with tightening up which fields are exposed // to the condition (w/ ajay and jim) - if (pattern.IsMatch(*trail.attribute()) == + if (pattern.IsMatch(trail.attribute()) == CelAttributePattern::MatchType::FULL) { return true; } @@ -34,7 +33,7 @@ bool AttributeUtility::CheckForUnknown(const AttributeTrail& trail, return false; } for (const auto& pattern : *unknown_patterns_) { - auto current_match = pattern.IsMatch(*trail.attribute()); + auto current_match = pattern.IsMatch(trail.attribute()); if (current_match == CelAttributePattern::MatchType::FULL || (use_partial && current_match == CelAttributePattern::MatchType::PARTIAL)) { @@ -50,20 +49,28 @@ bool AttributeUtility::CheckForUnknown(const AttributeTrail& trail, // Returns pointer to merged set or nullptr, if there were no sets to merge. const UnknownSet* AttributeUtility::MergeUnknowns( absl::Span args, const UnknownSet* initial_set) const { - const UnknownSet* result = initial_set; + absl::optional result_set; for (const auto& value : args) { if (!value.IsUnknownSet()) continue; auto current_set = value.UnknownSetOrDie(); - if (result == nullptr) { - result = current_set; - } else { - result = memory_manager_.New(*result, *current_set).release(); + if (!result_set.has_value()) { + if (initial_set != nullptr) { + result_set.emplace(*initial_set); + } else { + result_set.emplace(); + } } + result_set->Add(*current_set); } - return result; + if (!result_set.has_value()) { + return initial_set; + } + + return memory_manager_.New(std::move(result_set).value()) + .release(); } // Creates merged UnknownAttributeSet. @@ -73,15 +80,15 @@ const UnknownSet* AttributeUtility::MergeUnknowns( // Returns pointer to merged set or nullptr, if there were no sets to merge. UnknownAttributeSet AttributeUtility::CheckForUnknowns( absl::Span args, bool use_partial) const { - std::vector unknown_attrs; + UnknownAttributeSet attribute_set; - for (auto trail : args) { + for (const auto& trail : args) { if (CheckForUnknown(trail, use_partial)) { - unknown_attrs.push_back(trail.attribute()); + attribute_set.Add(trail.attribute()); } } - return UnknownAttributeSet(unknown_attrs); + return attribute_set; } // Creates merged UnknownAttributeSet. @@ -94,14 +101,18 @@ const UnknownSet* AttributeUtility::MergeUnknowns( absl::Span args, absl::Span attrs, const UnknownSet* initial_set, bool use_partial) const { UnknownAttributeSet attr_set = CheckForUnknowns(attrs, use_partial); - if (!attr_set.attributes().empty()) { + if (!attr_set.empty()) { + UnknownSet result_set(std::move(attr_set)); if (initial_set != nullptr) { - initial_set = - memory_manager_.New(*initial_set, UnknownSet(attr_set)) - .release(); - } else { - initial_set = memory_manager_.New(attr_set).release(); + result_set.Add(*initial_set); + } + for (const auto& value : args) { + if (!value.IsUnknownSet()) { + continue; + } + result_set.Add(*value.UnknownSetOrDie()); } + return memory_manager_.New(std::move(result_set)).release(); } return MergeUnknowns(args, initial_set); } diff --git a/eval/eval/attribute_utility.h b/eval/eval/attribute_utility.h index 906e8ad06..b8b0863b6 100644 --- a/eval/eval/attribute_utility.h +++ b/eval/eval/attribute_utility.h @@ -1,6 +1,7 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_UNKNOWNS_UTILITY_H_ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_UNKNOWNS_UTILITY_H_ +#include #include #include "google/protobuf/arena.h" @@ -69,8 +70,9 @@ class AttributeUtility { bool use_partial) const; // Create an initial UnknownSet from a single attribute. - const UnknownSet* CreateUnknownSet(const CelAttribute* attr) const { - return memory_manager_.New(UnknownAttributeSet({attr})) + const UnknownSet* CreateUnknownSet(CelAttribute attr) const { + return memory_manager_ + .New(UnknownAttributeSet({std::move(attr)})) .release(); } @@ -78,10 +80,9 @@ class AttributeUtility { const UnknownSet* CreateUnknownSet(const CelFunctionDescriptor& fn_descriptor, int64_t expr_id, absl::Span args) const { - auto* fn = - memory_manager_.New(fn_descriptor, expr_id) - .release(); - return memory_manager_.New(UnknownFunctionResultSet(fn)) + return memory_manager_ + .New(UnknownFunctionResultSet( + UnknownFunctionResult(fn_descriptor, expr_id))) .release(); } diff --git a/eval/eval/attribute_utility_test.cc b/eval/eval/attribute_utility_test.cc index fc80fd2ab..172a7fbe1 100644 --- a/eval/eval/attribute_utility_test.cc +++ b/eval/eval/attribute_utility_test.cc @@ -85,9 +85,9 @@ TEST(UnknownsUtilityTest, UnknownsUtilityMergeUnknownsFromValues) { AttributeUtility utility(&unknown_patterns, &missing_attribute_patterns, manager); - UnknownSet unknown_set0(UnknownAttributeSet({&attribute0})); - UnknownSet unknown_set1(UnknownAttributeSet({&attribute1})); - UnknownSet unknown_set2(UnknownAttributeSet({&attribute1, &attribute2})); + UnknownSet unknown_set0(UnknownAttributeSet({attribute0})); + UnknownSet unknown_set1(UnknownAttributeSet({attribute1})); + UnknownSet unknown_set2(UnknownAttributeSet({attribute1, attribute2})); std::vector values = { CelValue::CreateUnknownSet(&unknown_set0), CelValue::CreateUnknownSet(&unknown_set1), @@ -97,16 +97,16 @@ TEST(UnknownsUtilityTest, UnknownsUtilityMergeUnknownsFromValues) { const UnknownSet* unknown_set = utility.MergeUnknowns(values, nullptr); ASSERT_THAT(unknown_set, NotNull()); - ASSERT_THAT(unknown_set->unknown_attributes().attributes(), - UnorderedPointwise(Eq(), std::vector{ - &attribute0, &attribute1})); + ASSERT_THAT(unknown_set->unknown_attributes(), + UnorderedPointwise( + Eq(), std::vector{attribute0, attribute1})); unknown_set = utility.MergeUnknowns(values, &unknown_set2); ASSERT_THAT(unknown_set, NotNull()); ASSERT_THAT( - unknown_set->unknown_attributes().attributes(), - UnorderedPointwise(Eq(), std::vector{ - &attribute0, &attribute1, &attribute2})); + unknown_set->unknown_attributes(), + UnorderedPointwise( + Eq(), std::vector{attribute0, attribute1, attribute2})); } TEST(UnknownsUtilityTest, UnknownsUtilityCheckForUnknownsFromAttributes) { @@ -130,7 +130,7 @@ TEST(UnknownsUtilityTest, UnknownsUtilityCheckForUnknownsFromAttributes) { AttributeTrail trail1(unknown_expr1, manager); CelAttribute attribute1(unknown_expr1, {}); - UnknownSet unknown_set1(UnknownAttributeSet({&attribute1})); + UnknownSet unknown_set1(UnknownAttributeSet({attribute1})); AttributeUtility utility(&unknown_patterns, &missing_attribute_patterns, manager); @@ -147,7 +147,7 @@ TEST(UnknownsUtilityTest, UnknownsUtilityCheckForUnknownsFromAttributes) { UnknownSet unknown_set(unknown_set1, unknown_attr_set); - ASSERT_THAT(unknown_set.unknown_attributes().attributes(), SizeIs(3)); + ASSERT_THAT(unknown_set.unknown_attributes(), SizeIs(3)); } TEST(UnknownsUtilityTest, UnknownsUtilityCheckForMissingAttributes) { @@ -201,8 +201,7 @@ TEST(AttributeUtilityTest, CreateUnknownSet) { AttributeUtility utility(&empty_patterns, &empty_patterns, manager); const UnknownSet* set = utility.CreateUnknownSet(trail.attribute()); - EXPECT_EQ(*set->unknown_attributes().attributes().at(0)->AsString(), - "destination.ip"); + EXPECT_EQ(*set->unknown_attributes().begin()->AsString(), "destination.ip"); } } // namespace google::api::expr::runtime diff --git a/eval/eval/comprehension_step.cc b/eval/eval/comprehension_step.cc index 6a3d1ec3b..6f657aed3 100644 --- a/eval/eval/comprehension_step.cc +++ b/eval/eval/comprehension_step.cc @@ -130,11 +130,11 @@ absl::Status ComprehensionNextStep::Evaluate(ExecutionFrame* frame) const { CelValue current_value = (*cel_list)[current_index]; frame->value_stack().Push(CelValue::CreateInt64(current_index)); - auto iter_trail = iter_range_attr.Step( + AttributeTrail iter_trail = iter_range_attr.Step( CelAttributeQualifier::Create(CelValue::CreateInt64(current_index)), frame->memory_manager()); frame->value_stack().Push(current_value, iter_trail); - CEL_RETURN_IF_ERROR(frame->SetIterVar(current_value, iter_trail)); + CEL_RETURN_IF_ERROR(frame->SetIterVar(current_value, std::move(iter_trail))); return absl::OkStatus(); } diff --git a/eval/eval/comprehension_step_test.cc b/eval/eval/comprehension_step_test.cc index 38d3b2340..a1595aa69 100644 --- a/eval/eval/comprehension_step_test.cc +++ b/eval/eval/comprehension_step_test.cc @@ -154,12 +154,11 @@ TEST_F(ListKeysStepTest, MapPartiallyUnknown) { ASSERT_OK(eval_result); ASSERT_TRUE(eval_result->IsUnknownSet()); - const auto& attrs = - eval_result->UnknownSetOrDie()->unknown_attributes().attributes(); + const auto& attrs = eval_result->UnknownSetOrDie()->unknown_attributes(); EXPECT_THAT(attrs, SizeIs(1)); - EXPECT_THAT(attrs.at(0)->variable_name(), Eq("var")); - EXPECT_THAT(attrs.at(0)->qualifier_path(), SizeIs(0)); + EXPECT_THAT(attrs.begin()->variable_name(), Eq("var")); + EXPECT_THAT(attrs.begin()->qualifier_path(), SizeIs(0)); } TEST_F(ListKeysStepTest, ErrorPassedThrough) { @@ -209,8 +208,7 @@ TEST_F(ListKeysStepTest, UnknownSetPassedThrough) { ASSERT_OK(eval_result); ASSERT_TRUE(eval_result->IsUnknownSet()); - EXPECT_THAT(eval_result->UnknownSetOrDie()->unknown_attributes().attributes(), - SizeIs(1)); + EXPECT_THAT(eval_result->UnknownSetOrDie()->unknown_attributes(), SizeIs(1)); } } // namespace diff --git a/eval/eval/create_list_step_test.cc b/eval/eval/create_list_step_test.cc index 516f68cb1..7eb613800 100644 --- a/eval/eval/create_list_step_test.cc +++ b/eval/eval/create_list_step_test.cc @@ -147,7 +147,7 @@ TEST_P(CreateListStepTest, CreateListWithErrorAndUnknown) { Expr expr0; expr0.mutable_ident_expr()->set_name("name0"); CelAttribute attr0(expr0, {}); - UnknownSet unknown_set0(UnknownAttributeSet({&attr0})); + UnknownSet unknown_set0(UnknownAttributeSet({attr0})); values.push_back(CelValue::CreateUnknownSet(&unknown_set0)); CelError error = absl::InvalidArgumentError("bad arg"); values.push_back(CelValue::CreateError(&error)); @@ -185,8 +185,8 @@ TEST(CreateListStepTest, CreateListHundredAnd2Unknowns) { Expr expr1; expr1.mutable_ident_expr()->set_name("name1"); CelAttribute attr1(expr1, {}); - UnknownSet unknown_set0(UnknownAttributeSet({&attr0})); - UnknownSet unknown_set1(UnknownAttributeSet({&attr1})); + UnknownSet unknown_set0(UnknownAttributeSet({attr0})); + UnknownSet unknown_set1(UnknownAttributeSet({attr1})); for (size_t i = 0; i < 100; i++) { values.push_back(CelValue::CreateInt64(i)); } @@ -197,7 +197,7 @@ TEST(CreateListStepTest, CreateListHundredAnd2Unknowns) { RunExpressionWithCelValues(values, &arena, true)); ASSERT_TRUE(result.IsUnknownSet()); const UnknownSet* result_set = result.UnknownSetOrDie(); - EXPECT_THAT(result_set->unknown_attributes().attributes().size(), Eq(2)); + EXPECT_THAT(result_set->unknown_attributes().size(), Eq(2)); } INSTANTIATE_TEST_SUITE_P(CombinedCreateListTest, CreateListStepTest, diff --git a/eval/eval/evaluator_core_test.cc b/eval/eval/evaluator_core_test.cc index 5c8dd7450..63457aeba 100644 --- a/eval/eval/evaluator_core_test.cc +++ b/eval/eval/evaluator_core_test.cc @@ -132,8 +132,8 @@ TEST(EvaluatorCoreTest, ExecutionFrameSetGetClearVar) { int64_t result_value; ASSERT_TRUE(result.GetValue(&result_value)); EXPECT_EQ(test_value, result_value); - ASSERT_TRUE(trail->attribute()->has_variable_name()); - ASSERT_EQ(trail->attribute()->variable_name(), "var"); + ASSERT_TRUE(trail->attribute().has_variable_name()); + ASSERT_EQ(trail->attribute().variable_name(), "var"); // Test that it goes away properly ASSERT_OK(frame.ClearIterVar()); diff --git a/eval/eval/evaluator_stack.h b/eval/eval/evaluator_stack.h index 8e67bfa39..1ecab27a3 100644 --- a/eval/eval/evaluator_stack.h +++ b/eval/eval/evaluator_stack.h @@ -1,6 +1,8 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_EVALUATOR_STACK_H_ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_EVALUATOR_STACK_H_ +#include +#include #include #include "absl/types/span.h" @@ -82,7 +84,13 @@ class EvaluatorStack { LOG(ERROR) << "Trying to pop more elements (" << size << ") than the current stack size: " << current_size_; } - current_size_ -= size; + while (size > 0) { + size_t position = current_size_ - 1; + stack_[position] = CelValue::CreateNull(); + attribute_stack_[position] = AttributeTrail(); + current_size_--; + size--; + } } // Put element on the top of the stack. @@ -93,7 +101,7 @@ class EvaluatorStack { LOG(ERROR) << "No room to push more elements on to EvaluatorStack"; } stack_[current_size_] = value; - attribute_stack_[current_size_] = attribute; + attribute_stack_[current_size_] = std::move(attribute); current_size_++; } @@ -110,7 +118,7 @@ class EvaluatorStack { LOG(ERROR) << "Cannot PopAndPush on empty stack."; } stack_[current_size_ - 1] = value; - attribute_stack_[current_size_ - 1] = attribute; + attribute_stack_[current_size_ - 1] = std::move(attribute); } // Preallocate stack. diff --git a/eval/eval/evaluator_stack_test.cc b/eval/eval/evaluator_stack_test.cc index 98620041b..aa008c576 100644 --- a/eval/eval/evaluator_stack_test.cc +++ b/eval/eval/evaluator_stack_test.cc @@ -8,7 +8,6 @@ namespace google::api::expr::runtime { namespace { using ::cel::extensions::ProtoMemoryManager; -using testing::NotNull; // Test Value Stack Push/Pop operation TEST(EvaluatorStackTest, StackPushPop) { @@ -23,18 +22,18 @@ TEST(EvaluatorStackTest, StackPushPop) { stack.Push(CelValue::CreateInt64(3), AttributeTrail(expr, manager)); ASSERT_EQ(stack.Peek().Int64OrDie(), 3); - ASSERT_THAT(stack.PeekAttribute().attribute(), NotNull()); - ASSERT_EQ(*stack.PeekAttribute().attribute(), attribute); + ASSERT_FALSE(stack.PeekAttribute().empty()); + ASSERT_EQ(stack.PeekAttribute().attribute(), attribute); stack.Pop(1); ASSERT_EQ(stack.Peek().Int64OrDie(), 2); - ASSERT_EQ(stack.PeekAttribute().attribute(), nullptr); + ASSERT_TRUE(stack.PeekAttribute().empty()); stack.Pop(1); ASSERT_EQ(stack.Peek().Int64OrDie(), 1); - ASSERT_EQ(stack.PeekAttribute().attribute(), nullptr); + ASSERT_TRUE(stack.PeekAttribute().empty()); } // Test that inner stacks within value stack retain the equality of their sizes. diff --git a/eval/eval/function_step.cc b/eval/eval/function_step.cc index c305559c7..57521a134 100644 --- a/eval/eval/function_step.cc +++ b/eval/eval/function_step.cc @@ -73,7 +73,7 @@ std::vector CheckForPartialUnknowns( for (size_t i = 0; i < args.size(); i++) { auto attr_set = frame->attribute_utility().CheckForUnknowns( attrs.subspan(i, 1), /*use_partial=*/true); - if (!attr_set.attributes().empty()) { + if (!attr_set.empty()) { auto unknown_set = frame->memory_manager() .New(std::move(attr_set)) .release(); diff --git a/eval/eval/ident_step.cc b/eval/eval/ident_step.cc index d3fd44b68..722dd8fd5 100644 --- a/eval/eval/ident_step.cc +++ b/eval/eval/ident_step.cc @@ -2,6 +2,7 @@ #include #include +#include #include "google/protobuf/arena.h" #include "absl/status/status.h" @@ -54,9 +55,7 @@ absl::Status IdentStep::DoEvaluate(ExecutionFrame* frame, CelValue* result, // Populate trails if either MissingAttributeError or UnknownPattern // is enabled. if (frame->enable_missing_attribute_errors() || frame->enable_unknowns()) { - google::api::expr::v1alpha1::Expr expr; - expr.mutable_ident_expr()->set_name(name_); - *trail = AttributeTrail(std::move(expr), frame->memory_manager()); + *trail = AttributeTrail(name_); } if (frame->enable_missing_attribute_errors() && !name_.empty() && @@ -91,7 +90,7 @@ absl::Status IdentStep::Evaluate(ExecutionFrame* frame) const { CEL_RETURN_IF_ERROR(DoEvaluate(frame, &result, &trail)); - frame->value_stack().Push(result, trail); + frame->value_stack().Push(result, std::move(trail)); return absl::OkStatus(); } diff --git a/eval/eval/logic_step_test.cc b/eval/eval/logic_step_test.cc index 7584a4219..b491a9d1f 100644 --- a/eval/eval/logic_step_test.cc +++ b/eval/eval/logic_step_test.cc @@ -217,22 +217,20 @@ TEST_F(LogicStepTest, TestAndLogicUnknownHandling) { ident_expr1->set_name("name1"); CelAttribute attr0(expr0, {}), attr1(expr1, {}); - UnknownAttributeSet unknown_attr_set0({&attr0}); - UnknownAttributeSet unknown_attr_set1({&attr1}); + UnknownAttributeSet unknown_attr_set0({attr0}); + UnknownAttributeSet unknown_attr_set1({attr1}); UnknownSet unknown_set0(unknown_attr_set0); UnknownSet unknown_set1(unknown_attr_set1); - EXPECT_THAT(unknown_attr_set0.attributes().size(), Eq(1)); - EXPECT_THAT(unknown_attr_set1.attributes().size(), Eq(1)); + EXPECT_THAT(unknown_attr_set0.size(), Eq(1)); + EXPECT_THAT(unknown_attr_set1.size(), Eq(1)); status = EvaluateLogic(CelValue::CreateUnknownSet(&unknown_set0), CelValue::CreateUnknownSet(&unknown_set1), false, &result, true); ASSERT_OK(status); ASSERT_TRUE(result.IsUnknownSet()); - ASSERT_THAT( - result.UnknownSetOrDie()->unknown_attributes().attributes().size(), - Eq(2)); + ASSERT_THAT(result.UnknownSetOrDie()->unknown_attributes().size(), Eq(2)); } TEST_F(LogicStepTest, TestOrLogicUnknownHandling) { @@ -280,23 +278,21 @@ TEST_F(LogicStepTest, TestOrLogicUnknownHandling) { ident_expr1->set_name("name1"); CelAttribute attr0(expr0, {}), attr1(expr1, {}); - UnknownAttributeSet unknown_attr_set0({&attr0}); - UnknownAttributeSet unknown_attr_set1({&attr1}); + UnknownAttributeSet unknown_attr_set0({attr0}); + UnknownAttributeSet unknown_attr_set1({attr1}); UnknownSet unknown_set0(unknown_attr_set0); UnknownSet unknown_set1(unknown_attr_set1); - EXPECT_THAT(unknown_attr_set0.attributes().size(), Eq(1)); - EXPECT_THAT(unknown_attr_set1.attributes().size(), Eq(1)); + EXPECT_THAT(unknown_attr_set0.size(), Eq(1)); + EXPECT_THAT(unknown_attr_set1.size(), Eq(1)); status = EvaluateLogic(CelValue::CreateUnknownSet(&unknown_set0), CelValue::CreateUnknownSet(&unknown_set1), true, &result, true); ASSERT_OK(status); ASSERT_TRUE(result.IsUnknownSet()); - ASSERT_THAT( - result.UnknownSetOrDie()->unknown_attributes().attributes().size(), - Eq(2)); + ASSERT_THAT(result.UnknownSetOrDie()->unknown_attributes().size(), Eq(2)); } INSTANTIATE_TEST_SUITE_P(LogicStepTest, LogicStepTest, testing::Bool()); diff --git a/eval/eval/select_step.cc b/eval/eval/select_step.cc index e3be86bb8..54d7c4abf 100644 --- a/eval/eval/select_step.cc +++ b/eval/eval/select_step.cc @@ -85,7 +85,7 @@ absl::optional CheckForMarkedAttributes(const AttributeTrail& trail, if (frame->enable_missing_attribute_errors() && frame->attribute_utility().CheckForMissingAttribute(trail)) { - auto attribute_string = trail.attribute()->AsString(); + auto attribute_string = trail.attribute().AsString(); if (attribute_string.ok()) { return CreateMissingAttributeError(frame->memory_manager(), *attribute_string); @@ -153,7 +153,7 @@ absl::Status SelectStep::Evaluate(ExecutionFrame* frame) const { if (arg.IsNull()) { CelValue error_value = CreateErrorValue(frame->memory_manager(), "Message is NULL"); - frame->value_stack().PopAndPush(error_value, result_trail); + frame->value_stack().PopAndPush(error_value, std::move(result_trail)); return absl::OkStatus(); } @@ -165,7 +165,7 @@ absl::Status SelectStep::Evaluate(ExecutionFrame* frame) const { CheckForMarkedAttributes(result_trail, frame); if (marked_attribute_check.has_value()) { frame->value_stack().PopAndPush(marked_attribute_check.value(), - result_trail); + std::move(result_trail)); return absl::OkStatus(); } @@ -175,7 +175,7 @@ absl::Status SelectStep::Evaluate(ExecutionFrame* frame) const { if (arg.MapOrDie() == nullptr) { frame->value_stack().PopAndPush( CreateErrorValue(frame->memory_manager(), "Map is NULL"), - result_trail); + std::move(result_trail)); return absl::OkStatus(); } break; @@ -185,7 +185,7 @@ absl::Status SelectStep::Evaluate(ExecutionFrame* frame) const { arg.GetValue(&w) && w.message_ptr() == nullptr) { frame->value_stack().PopAndPush( CreateErrorValue(frame->memory_manager(), "Message is NULL"), - result_trail); + std::move(result_trail)); return absl::OkStatus(); } break; @@ -218,7 +218,7 @@ absl::Status SelectStep::Evaluate(ExecutionFrame* frame) const { CEL_RETURN_IF_ERROR( CreateValueFromField(wrapper, frame->memory_manager(), &result)); - frame->value_stack().PopAndPush(result, result_trail); + frame->value_stack().PopAndPush(result, std::move(result_trail)); return absl::OkStatus(); } @@ -235,7 +235,7 @@ absl::Status SelectStep::Evaluate(ExecutionFrame* frame) const { } else { result = CreateNoSuchKeyError(frame->memory_manager(), field_); } - frame->value_stack().PopAndPush(result, result_trail); + frame->value_stack().PopAndPush(result, std::move(result_trail)); return absl::OkStatus(); } default: diff --git a/eval/eval/ternary_step_test.cc b/eval/eval/ternary_step_test.cc index e0ccb44e1..4cd694cd4 100644 --- a/eval/eval/ternary_step_test.cc +++ b/eval/eval/ternary_step_test.cc @@ -147,22 +147,21 @@ TEST_F(LogicStepTest, TestUnknownHandling) { CelAttribute attr0(expr0.ident_expr().name(), {}), attr1(expr1.ident_expr().name(), {}); - UnknownAttributeSet unknown_attr_set0({&attr0}); - UnknownAttributeSet unknown_attr_set1({&attr1}); + UnknownAttributeSet unknown_attr_set0({attr0}); + UnknownAttributeSet unknown_attr_set1({attr1}); UnknownSet unknown_set0(unknown_attr_set0); UnknownSet unknown_set1(unknown_attr_set1); - EXPECT_THAT(unknown_attr_set0.attributes().size(), Eq(1)); - EXPECT_THAT(unknown_attr_set1.attributes().size(), Eq(1)); + EXPECT_THAT(unknown_attr_set0.size(), Eq(1)); + EXPECT_THAT(unknown_attr_set1.size(), Eq(1)); ASSERT_OK(EvaluateLogic(CelValue::CreateUnknownSet(&unknown_set0), CelValue::CreateUnknownSet(&unknown_set1), CelValue::CreateBool(false), &result, true)); ASSERT_TRUE(result.IsUnknownSet()); - const auto& attrs = - result.UnknownSetOrDie()->unknown_attributes().attributes(); + const auto& attrs = result.UnknownSetOrDie()->unknown_attributes(); ASSERT_THAT(attrs, testing::SizeIs(1)); - EXPECT_THAT(attrs[0]->variable_name(), Eq("name0")); + EXPECT_THAT(attrs.begin()->variable_name(), Eq("name0")); } INSTANTIATE_TEST_SUITE_P(LogicStepTest, LogicStepTest, testing::Bool()); diff --git a/eval/public/BUILD b/eval/public/BUILD index 73d99dee4..de25a91d9 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -123,13 +123,12 @@ cc_library( cc_library( name = "unknown_attribute_set", - srcs = [ - ], hdrs = [ "unknown_attribute_set.h", ], deps = [ ":cel_attribute", + "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -907,9 +906,6 @@ cc_library( hdrs = ["unknown_function_result_set.h"], deps = [ ":cel_function", - ":cel_options", - ":cel_value", - ":set_util", "@com_google_absl//absl/container:btree", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", ], diff --git a/eval/public/activation_test.cc b/eval/public/activation_test.cc index e225ea05a..010d7d6d1 100644 --- a/eval/public/activation_test.cc +++ b/eval/public/activation_test.cc @@ -223,13 +223,13 @@ TEST(ActivationTest, ErrorPathTest) { trail = trail.Step( CelAttributeQualifier::Create(CelValue::CreateStringView("ip")), manager); - ASSERT_EQ(destination_ip_pattern.IsMatch(*trail.attribute()), + ASSERT_EQ(destination_ip_pattern.IsMatch(trail.attribute()), CelAttributePattern::MatchType::FULL); EXPECT_TRUE(activation.missing_attribute_patterns().empty()); activation.set_missing_attribute_patterns({destination_ip_pattern}); EXPECT_EQ( - activation.missing_attribute_patterns()[0].IsMatch(*trail.attribute()), + activation.missing_attribute_patterns()[0].IsMatch(trail.attribute()), CelAttributePattern::MatchType::FULL); } diff --git a/eval/public/cel_attribute.cc b/eval/public/cel_attribute.cc index 2498183f2..60abb40e9 100644 --- a/eval/public/cel_attribute.cc +++ b/eval/public/cel_attribute.cc @@ -130,8 +130,104 @@ struct CelAttributeQualifierIsMatchVisitor final { } }; +struct CelAttributeQualifierTypeComparator final { + const CelValue::Type lhs; + + bool operator()(const CelValue::Type& rhs) const { + return static_cast(lhs) < static_cast(rhs); + } + + bool operator()(int64_t) const { return false; } + + bool operator()(uint64_t other) const { return false; } + + bool operator()(const std::string&) const { return false; } + + bool operator()(bool other) const { return false; } +}; + +struct CelAttributeQualifierIntComparator final { + const int64_t lhs; + + bool operator()(const CelValue::Type&) const { return true; } + + bool operator()(int64_t rhs) const { return lhs < rhs; } + + bool operator()(uint64_t) const { return true; } + + bool operator()(const std::string&) const { return true; } + + bool operator()(bool) const { return false; } +}; + +struct CelAttributeQualifierUintComparator final { + const uint64_t lhs; + + bool operator()(const CelValue::Type&) const { return true; } + + bool operator()(int64_t) const { return false; } + + bool operator()(uint64_t rhs) const { return lhs < rhs; } + + bool operator()(const std::string&) const { return true; } + + bool operator()(bool) const { return false; } +}; + +struct CelAttributeQualifierStringComparator final { + const std::string& lhs; + + bool operator()(const CelValue::Type&) const { return true; } + + bool operator()(int64_t) const { return false; } + + bool operator()(uint64_t) const { return false; } + + bool operator()(const std::string& rhs) const { return lhs < rhs; } + + bool operator()(bool) const { return false; } +}; + +struct CelAttributeQualifierBoolComparator final { + const bool lhs; + + bool operator()(const CelValue::Type&) const { return true; } + + bool operator()(int64_t) const { return true; } + + bool operator()(uint64_t) const { return true; } + + bool operator()(const std::string&) const { return true; } + + bool operator()(bool rhs) const { return lhs < rhs; } +}; + } // namespace +struct CelAttributeQualifier::ComparatorVisitor final { + const CelAttributeQualifier::Variant& rhs; + + bool operator()(const CelValue::Type& lhs) const { + return absl::visit(CelAttributeQualifierTypeComparator{lhs}, rhs); + } + + bool operator()(int64_t lhs) const { + return absl::visit(CelAttributeQualifierIntComparator{lhs}, rhs); + } + + bool operator()(uint64_t lhs) const { + return absl::visit(CelAttributeQualifierUintComparator{lhs}, rhs); + } + + bool operator()(const std::string& lhs) const { + return absl::visit(CelAttributeQualifierStringComparator{lhs}, rhs); + } + + bool operator()(bool lhs) const { + return absl::visit(CelAttributeQualifierBoolComparator{lhs}, rhs); + } +}; + CelValue::Type CelAttributeQualifier::type() const { return std::visit(CelAttributeQualifierTypeVisitor{}, value_); } @@ -154,6 +250,14 @@ CelAttributeQualifier CelAttributeQualifier::Create(CelValue value) { } } +bool CelAttributeQualifier::operator<( + const CelAttributeQualifier& other) const { + // The order is not publicly documented because it is subject to change. + // Currently we sort in the following order, with each type being sorted + // against itself: bool, int, uint, string, type. + return absl::visit(ComparatorVisitor{other.value_}, value_); +} + CelAttributePattern CreateCelAttributePattern( absl::string_view variable, std::initializer_list CelAttribute::AsString() const { if (variable_name().empty()) { return absl::InvalidArgumentError( @@ -194,7 +331,7 @@ const absl::StatusOr CelAttribute::AsString() const { std::string result = std::string(variable_name()); - for (const auto& qualifier : qualifier_path_) { + for (const auto& qualifier : qualifier_path()) { CEL_RETURN_IF_ERROR( std::visit(CelAttributeStringPrinter(&result, qualifier.type()), qualifier.value_)); diff --git a/eval/public/cel_attribute.h b/eval/public/cel_attribute.h index fa9a29a37..2ce114dc0 100644 --- a/eval/public/cel_attribute.h +++ b/eval/public/cel_attribute.h @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -30,6 +31,12 @@ namespace google::api::expr::runtime { // attribute resolutuion path. A segment can be qualified by values of // following types: string/int64_t/uint64/bool. class CelAttributeQualifier { + private: + struct ComparatorVisitor; + + using Variant = + std::variant; + public: // Factory method. static CelAttributeQualifier Create(CelValue value); @@ -72,6 +79,8 @@ class CelAttributeQualifier { return IsMatch(other); } + bool operator<(const CelAttributeQualifier& other) const; + bool IsMatch(const CelValue& cel_value) const; bool IsMatch(absl::string_view other_key) const { @@ -81,6 +90,7 @@ class CelAttributeQualifier { private: friend class CelAttribute; + friend struct ComparatorVisitor; CelAttributeQualifier() = default; @@ -94,7 +104,7 @@ class CelAttributeQualifier { // instances, regardless of whether they are supported in this context or not. // We represented unsupported types by using the first alternative and thus // preserve backwards compatibility with the result of `type()` above. - std::variant value_; + Variant value_; }; // CelAttributeQualifierPattern matches a segment in @@ -152,30 +162,44 @@ class CelAttributeQualifierPattern { // CelAttribute represents resolved attribute path. class CelAttribute { public: + explicit CelAttribute(std::string variable_name) + : CelAttribute(std::move(variable_name), {}) {} + CelAttribute(std::string variable_name, std::vector qualifier_path) - : variable_name_(std::move(variable_name)), - qualifier_path_(std::move(qualifier_path)) {} + : impl_(std::make_shared(std::move(variable_name), + std::move(qualifier_path))) {} CelAttribute(const google::api::expr::v1alpha1::Expr& variable, std::vector qualifier_path) : CelAttribute(variable.ident_expr().name(), std::move(qualifier_path)) {} - absl::string_view variable_name() const { return variable_name_; } + absl::string_view variable_name() const { return impl_->variable_name; } - bool has_variable_name() const { return !variable_name_.empty(); } + bool has_variable_name() const { return !impl_->variable_name.empty(); } const std::vector& qualifier_path() const { - return qualifier_path_; + return impl_->qualifier_path; } bool operator==(const CelAttribute& other) const; + bool operator<(const CelAttribute& other) const; + const absl::StatusOr AsString() const; private: - std::string variable_name_; - std::vector qualifier_path_; + struct Impl final { + Impl(std::string variable_name, + std::vector qualifier_path) + : variable_name(std::move(variable_name)), + qualifier_path(std::move(qualifier_path)) {} + + std::string variable_name; + std::vector qualifier_path; + }; + + std::shared_ptr impl_; }; // CelAttributePattern is a fully-qualified absolute attribute path pattern. diff --git a/eval/public/cel_function.cc b/eval/public/cel_function.cc index 75370e8df..2c8df134d 100644 --- a/eval/public/cel_function.cc +++ b/eval/public/cel_function.cc @@ -1,19 +1,25 @@ #include "eval/public/cel_function.h" +#include +#include +#include +#include +#include + namespace google::api::expr::runtime { bool CelFunctionDescriptor::ShapeMatches( bool receiver_style, const std::vector& types) const { - if (receiver_style_ != receiver_style) { + if (this->receiver_style() != receiver_style) { return false; } - if (types_.size() != types.size()) { + if (this->types().size() != types.size()) { return false; } - for (size_t i = 0; i < types_.size(); i++) { - CelValue::Type this_type = types_[i]; + for (size_t i = 0; i < this->types().size(); i++) { + CelValue::Type this_type = this->types()[i]; CelValue::Type other_type = types[i]; if (this_type != CelValue::Type::kAny && other_type != CelValue::Type::kAny && this_type != other_type) { @@ -23,6 +29,59 @@ bool CelFunctionDescriptor::ShapeMatches( return true; } +bool CelFunctionDescriptor::operator==( + const CelFunctionDescriptor& other) const { + return impl_.get() == other.impl_.get() || + (name() == other.name() && + receiver_style() == other.receiver_style() && + types().size() == other.types().size() && + std::equal(types().begin(), types().end(), other.types().begin())); +} + +bool CelFunctionDescriptor::operator<( + const CelFunctionDescriptor& other) const { + if (impl_.get() == other.impl_.get()) { + return false; + } + if (name() < other.name()) { + return true; + } + if (name() != other.name()) { + return false; + } + if (receiver_style() < other.receiver_style()) { + return true; + } + if (receiver_style() != other.receiver_style()) { + return false; + } + auto lhs_begin = types().begin(); + auto lhs_end = types().end(); + auto rhs_begin = other.types().begin(); + auto rhs_end = other.types().end(); + while (lhs_begin != lhs_end && rhs_begin != rhs_end) { + if (*lhs_begin < *rhs_begin) { + return true; + } + if (!(*lhs_begin == *rhs_begin)) { + return false; + } + lhs_begin++; + rhs_begin++; + } + if (lhs_begin == lhs_end && rhs_begin == rhs_end) { + // Neither has any elements left, they are equal. + return false; + } + if (lhs_begin == lhs_end) { + // Left has no more elements. Right is greater. + return true; + } + // Right has no more elements. Left is greater. + ABSL_ASSERT(rhs_begin == rhs_end); + return false; +} + bool CelFunction::MatchArguments(absl::Span arguments) const { auto types_size = descriptor().types().size(); diff --git a/eval/public/cel_function.h b/eval/public/cel_function.h index d60a107e3..c8294ab5d 100644 --- a/eval/public/cel_function.h +++ b/eval/public/cel_function.h @@ -1,6 +1,7 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_FUNCTION_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_FUNCTION_H_ +#include #include #include #include @@ -19,24 +20,22 @@ class CelFunctionDescriptor { CelFunctionDescriptor(absl::string_view name, bool receiver_style, std::vector types, bool is_strict = true) - : name_(name), - receiver_style_(receiver_style), - types_(std::move(types)), - is_strict_(is_strict) {} + : impl_(std::make_shared(name, receiver_style, std::move(types), + is_strict)) {} // Function name. - const std::string& name() const { return name_; } + const std::string& name() const { return impl_->name; } // Whether function is receiver style i.e. true means arg0.name(args[1:]...). - bool receiver_style() const { return receiver_style_; } + bool receiver_style() const { return impl_->receiver_style; } // The argmument types the function accepts. - const std::vector& types() const { return types_; } + const std::vector& types() const { return impl_->types; } // if true (strict, default), error or unknown arguments are propagated // instead of calling the function. if false (non-strict), the function may // receive error or unknown values as arguments. - bool is_strict() const { return is_strict_; } + bool is_strict() const { return impl_->is_strict; } // Helper for matching a descriptor. This tests that the shape is the same -- // |other| accepts the same number and types of arguments and is the same call @@ -47,11 +46,26 @@ class CelFunctionDescriptor { bool ShapeMatches(bool receiver_style, const std::vector& types) const; + bool operator==(const CelFunctionDescriptor& other) const; + + bool operator<(const CelFunctionDescriptor& other) const; + private: - std::string name_; - bool receiver_style_; - std::vector types_; - bool is_strict_; + struct Impl final { + Impl(absl::string_view name, bool receiver_style, + std::vector types, bool is_strict) + : name(name), + receiver_style(receiver_style), + types(std::move(types)), + is_strict(is_strict) {} + + std::string name; + bool receiver_style; + std::vector types; + bool is_strict; + }; + + std::shared_ptr impl_; }; // CelFunction is a handler that represents single diff --git a/eval/public/unknown_attribute_set.h b/eval/public/unknown_attribute_set.h index a661de69f..4ccf0873c 100644 --- a/eval/public/unknown_attribute_set.h +++ b/eval/public/unknown_attribute_set.h @@ -2,7 +2,9 @@ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_UNKNOWN_ATTRIBUTE_SET_H_ #include +#include +#include "absl/container/btree_set.h" #include "absl/container/flat_hash_set.h" #include "eval/public/cel_attribute.h" @@ -11,17 +13,28 @@ namespace api { namespace expr { namespace runtime { +class AttributeUtility; +class UnknownSet; + // UnknownAttributeSet is a container for CEL attributes that are identified as // unknown during expression evaluation. -class UnknownAttributeSet { +class UnknownAttributeSet final { + private: + using Container = absl::btree_set; + public: - UnknownAttributeSet(const UnknownAttributeSet& other) = default; - UnknownAttributeSet& operator=(const UnknownAttributeSet& other) = default; + using value_type = typename Container::value_type; + using size_type = typename Container::size_type; + using iterator = typename Container::const_iterator; + using const_iterator = typename Container::const_iterator; + + UnknownAttributeSet() = default; + UnknownAttributeSet(const UnknownAttributeSet&) = default; + UnknownAttributeSet(UnknownAttributeSet&&) = default; + UnknownAttributeSet& operator=(const UnknownAttributeSet&) = default; + UnknownAttributeSet& operator=(UnknownAttributeSet&&) = default; - UnknownAttributeSet() {} - explicit UnknownAttributeSet( - const std::vector& attributes) { - attributes_.reserve(attributes.size()); + explicit UnknownAttributeSet(const std::vector& attributes) { for (const auto& attr : attributes) { Add(attr); } @@ -29,14 +42,31 @@ class UnknownAttributeSet { UnknownAttributeSet(const UnknownAttributeSet& set1, const UnknownAttributeSet& set2) - : attributes_(set1.attributes()) { - attributes_.reserve(set1.attributes().size() + set2.attributes().size()); - for (const auto& attr : set2.attributes()) { + : attributes_(set1.attributes_) { + for (const auto& attr : set2.attributes_) { Add(attr); } } - std::vector attributes() const { return attributes_; } + iterator begin() const { return attributes_.begin(); } + + const_iterator cbegin() const { return attributes_.cbegin(); } + + iterator end() const { return attributes_.end(); } + + const_iterator cend() const { return attributes_.cend(); } + + size_type size() const { return attributes_.size(); } + + bool empty() const { return attributes_.empty(); } + + bool operator==(const UnknownAttributeSet& other) const { + return this == &other || attributes_ == other.attributes_; + } + + bool operator!=(const UnknownAttributeSet& other) const { + return !operator==(other); + } static UnknownAttributeSet Merge(const UnknownAttributeSet& set1, const UnknownAttributeSet& set2) { @@ -44,20 +74,19 @@ class UnknownAttributeSet { } private: - void Add(const CelAttribute* attribute) { - if (!attribute) { - return; - } - for (auto attr : attributes_) { - if (*attr == *attribute) { - return; - } + friend class AttributeUtility; + friend class UnknownSet; + + void Add(const CelAttribute& attribute) { attributes_.insert(attribute); } + + void Add(const UnknownAttributeSet& other) { + for (const auto& attribute : other) { + Add(attribute); } - attributes_.push_back(attribute); } // Attribute container. - std::vector attributes_; + Container attributes_; }; } // namespace runtime diff --git a/eval/public/unknown_attribute_set_test.cc b/eval/public/unknown_attribute_set_test.cc index a2113ed69..a90f7124f 100644 --- a/eval/public/unknown_attribute_set_test.cc +++ b/eval/public/unknown_attribute_set_test.cc @@ -33,9 +33,9 @@ TEST(UnknownAttributeSetTest, TestCreate) { CelAttributeQualifier::Create(CelValue::CreateUint64(2)), CelAttributeQualifier::Create(CelValue::CreateBool(true))})); - UnknownAttributeSet unknown_set({cel_attr.get()}); - EXPECT_THAT(unknown_set.attributes().size(), Eq(1)); - EXPECT_THAT(*(unknown_set.attributes()[0]), Eq(*cel_attr)); + UnknownAttributeSet unknown_set({*cel_attr}); + EXPECT_THAT(unknown_set.size(), Eq(1)); + EXPECT_THAT(*(unknown_set.begin()), Eq(*cel_attr)); } TEST(UnknownAttributeSetTest, TestMergeSets) { @@ -46,47 +46,47 @@ TEST(UnknownAttributeSetTest, TestMergeSets) { const std::string kAttr2 = "a2"; const std::string kAttr3 = "a3"; - std::shared_ptr cel_attr1 = std::make_shared( + CelAttribute cel_attr1( expr, std::vector( {CelAttributeQualifier::Create(CelValue::CreateString(&kAttr1)), CelAttributeQualifier::Create(CelValue::CreateInt64(1)), CelAttributeQualifier::Create(CelValue::CreateUint64(2)), CelAttributeQualifier::Create(CelValue::CreateBool(true))})); - std::shared_ptr cel_attr1_copy = std::make_shared( + CelAttribute cel_attr1_copy( expr, std::vector( {CelAttributeQualifier::Create(CelValue::CreateString(&kAttr1)), CelAttributeQualifier::Create(CelValue::CreateInt64(1)), CelAttributeQualifier::Create(CelValue::CreateUint64(2)), CelAttributeQualifier::Create(CelValue::CreateBool(true))})); - std::shared_ptr cel_attr2 = std::make_shared( + CelAttribute cel_attr2( expr, std::vector( {CelAttributeQualifier::Create(CelValue::CreateString(&kAttr1)), CelAttributeQualifier::Create(CelValue::CreateInt64(2)), CelAttributeQualifier::Create(CelValue::CreateUint64(2)), CelAttributeQualifier::Create(CelValue::CreateBool(true))})); - std::shared_ptr cel_attr3 = std::make_shared( + CelAttribute cel_attr3( expr, std::vector( {CelAttributeQualifier::Create(CelValue::CreateString(&kAttr1)), CelAttributeQualifier::Create(CelValue::CreateInt64(2)), CelAttributeQualifier::Create(CelValue::CreateUint64(2)), CelAttributeQualifier::Create(CelValue::CreateBool(false))})); - UnknownAttributeSet unknown_set1({cel_attr1.get(), cel_attr2.get()}); - UnknownAttributeSet unknown_set2({cel_attr1_copy.get(), cel_attr3.get()}); + UnknownAttributeSet unknown_set1({cel_attr1, cel_attr2}); + UnknownAttributeSet unknown_set2({cel_attr1_copy, cel_attr3}); UnknownAttributeSet unknown_set3 = UnknownAttributeSet::Merge(unknown_set1, unknown_set2); - EXPECT_THAT(unknown_set3.attributes().size(), Eq(3)); + EXPECT_THAT(unknown_set3.size(), Eq(3)); std::vector attrs1; - for (auto attr_ptr : unknown_set3.attributes()) { - attrs1.push_back(*attr_ptr); + for (const auto& attr_ptr : unknown_set3) { + attrs1.push_back(attr_ptr); } - std::vector attrs2 = {*cel_attr1, *cel_attr2, *cel_attr3}; + std::vector attrs2 = {cel_attr1, cel_attr2, cel_attr3}; EXPECT_THAT(attrs1, testing::UnorderedPointwise(Eq(), attrs2)); } diff --git a/eval/public/unknown_function_result_set.cc b/eval/public/unknown_function_result_set.cc index 75361c263..ccce69dcd 100644 --- a/eval/public/unknown_function_result_set.cc +++ b/eval/public/unknown_function_result_set.cc @@ -1,78 +1,16 @@ #include "eval/public/unknown_function_result_set.h" -#include - -#include "absl/container/btree_set.h" -#include "eval/public/cel_function.h" -#include "eval/public/cel_options.h" -#include "eval/public/cel_value.h" -#include "eval/public/set_util.h" - namespace google { namespace api { namespace expr { namespace runtime { -namespace { - -// Tests that lhs descriptor is less than (name, receiver call style, -// arg types). -// Argument type Any is not treated specially. For example: -// {"f", false, {kAny}} > {"f", false, {kInt64}} -bool DescriptorLessThan(const CelFunctionDescriptor& lhs, - const CelFunctionDescriptor& rhs) { - if (lhs.name() < rhs.name()) { - return true; - } - if (lhs.name() > rhs.name()) { - return false; - } - - if (lhs.receiver_style() < rhs.receiver_style()) { - return true; - } - if (lhs.receiver_style() > rhs.receiver_style()) { - return false; - } - - if (lhs.types() >= rhs.types()) { - return false; - } - - return true; -} - -bool UnknownFunctionResultLessThan(const UnknownFunctionResult& lhs, - const UnknownFunctionResult& rhs) { - if (DescriptorLessThan(lhs.descriptor(), rhs.descriptor())) { - return true; - } - if (DescriptorLessThan(rhs.descriptor(), lhs.descriptor())) { - return false; - } - - // equal - return false; -} - -} // namespace - -bool UnknownFunctionComparator::operator()( - const UnknownFunctionResult* lhs, const UnknownFunctionResult* rhs) const { - return UnknownFunctionResultLessThan(*lhs, *rhs); -} - -bool UnknownFunctionResult::IsEqualTo( - const UnknownFunctionResult& other) const { - return !(UnknownFunctionResultLessThan(*this, other) || - UnknownFunctionResultLessThan(other, *this)); -} // Implementation for merge constructor. UnknownFunctionResultSet::UnknownFunctionResultSet( const UnknownFunctionResultSet& lhs, const UnknownFunctionResultSet& rhs) - : unknown_function_results_(lhs.unknown_function_results()) { - for (const UnknownFunctionResult* call : rhs.unknown_function_results()) { - unknown_function_results_.insert(call); + : function_results_(lhs.function_results_) { + for (const auto& function_result : rhs) { + function_results_.insert(function_result); } } diff --git a/eval/public/unknown_function_result_set.h b/eval/public/unknown_function_result_set.h index ed13c3985..61f943e13 100644 --- a/eval/public/unknown_function_result_set.h +++ b/eval/public/unknown_function_result_set.h @@ -1,10 +1,10 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_UNKNOWN_FUNCTION_RESULT_SET_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_UNKNOWN_FUNCTION_RESULT_SET_H_ +#include #include #include "google/api/expr/v1alpha1/syntax.pb.h" -#include "absl/container/btree_map.h" #include "absl/container/btree_set.h" #include "eval/public/cel_function.h" @@ -15,10 +15,16 @@ namespace runtime { // Represents a function result that is unknown at the time of execution. This // allows for lazy evaluation of expensive functions. -class UnknownFunctionResult { +class UnknownFunctionResult final { public: - UnknownFunctionResult(const CelFunctionDescriptor& descriptor, int64_t expr_id) - : descriptor_(descriptor), expr_id_(expr_id) {} + UnknownFunctionResult() = default; + UnknownFunctionResult(const UnknownFunctionResult&) = default; + UnknownFunctionResult(UnknownFunctionResult&&) = default; + UnknownFunctionResult& operator=(const UnknownFunctionResult&) = default; + UnknownFunctionResult& operator=(UnknownFunctionResult&&) = default; + + UnknownFunctionResult(CelFunctionDescriptor descriptor, int64_t expr_id) + : descriptor_(std::move(descriptor)), expr_id_(expr_id) {} // The descriptor of the called function that return Unknown. const CelFunctionDescriptor& descriptor() const { return descriptor_; } @@ -31,7 +37,9 @@ class UnknownFunctionResult { // Equality operator provided for testing. Compatible with set less-than // comparator. // Compares descriptor then arguments elementwise. - bool IsEqualTo(const UnknownFunctionResult& other) const; + bool IsEqualTo(const UnknownFunctionResult& other) const { + return descriptor() == other.descriptor(); + } // TODO(issues/5): re-implement argument capture @@ -40,38 +48,86 @@ class UnknownFunctionResult { int64_t expr_id_; }; -// Comparator for set semantics. -struct UnknownFunctionComparator { - bool operator()(const UnknownFunctionResult*, - const UnknownFunctionResult*) const; -}; +inline bool operator==(const UnknownFunctionResult& lhs, + const UnknownFunctionResult& rhs) { + return lhs.IsEqualTo(rhs); +} + +inline bool operator<(const UnknownFunctionResult& lhs, + const UnknownFunctionResult& rhs) { + return lhs.descriptor() < rhs.descriptor(); +} + +class AttributeUtility; +class UnknownSet; // Represents a collection of unknown function results at a particular point in // execution. Execution should advance further if this set of unknowns are // provided. It may not advance if only a subset are provided. // Set semantics use |IsEqualTo()| defined on |UnknownFunctionResult|. -class UnknownFunctionResultSet { +class UnknownFunctionResultSet final { + private: + using Container = absl::btree_set; + public: - // Empty set - UnknownFunctionResultSet() {} + using value_type = typename Container::value_type; + using size_type = typename Container::size_type; + using iterator = typename Container::const_iterator; + using const_iterator = typename Container::const_iterator; + + UnknownFunctionResultSet() = default; + UnknownFunctionResultSet(const UnknownFunctionResultSet&) = default; + UnknownFunctionResultSet(UnknownFunctionResultSet&&) = default; + UnknownFunctionResultSet& operator=(const UnknownFunctionResultSet&) = + default; + UnknownFunctionResultSet& operator=(UnknownFunctionResultSet&&) = default; // Merge constructor -- effectively union(lhs, rhs). UnknownFunctionResultSet(const UnknownFunctionResultSet& lhs, const UnknownFunctionResultSet& rhs); // Initialize with a single UnknownFunctionResult. - UnknownFunctionResultSet(const UnknownFunctionResult* initial) - : unknown_function_results_{initial} {} + explicit UnknownFunctionResultSet(UnknownFunctionResult initial) + : function_results_{std::move(initial)} {} + + UnknownFunctionResultSet(std::initializer_list il) + : function_results_(il) {} + + iterator begin() const { return function_results_.begin(); } + + const_iterator cbegin() const { return function_results_.cbegin(); } - using Container = - absl::btree_set; + iterator end() const { return function_results_.end(); } - const Container& unknown_function_results() const { - return unknown_function_results_; + const_iterator cend() const { return function_results_.cend(); } + + size_type size() const { return function_results_.size(); } + + bool empty() const { return function_results_.empty(); } + + bool operator==(const UnknownFunctionResultSet& other) const { + return this == &other || function_results_ == other.function_results_; + } + + bool operator!=(const UnknownFunctionResultSet& other) const { + return !operator==(other); } private: - Container unknown_function_results_; + friend class AttributeUtility; + friend class UnknownSet; + + void Add(const UnknownFunctionResult& function_result) { + function_results_.insert(function_result); + } + + void Add(const UnknownFunctionResultSet& other) { + for (const auto& function_result : other) { + Add(function_result); + } + } + + Container function_results_; }; } // namespace runtime diff --git a/eval/public/unknown_function_result_set_test.cc b/eval/public/unknown_function_result_set_test.cc index a4005a54c..f2da7b475 100644 --- a/eval/public/unknown_function_result_set_test.cc +++ b/eval/public/unknown_function_result_set_test.cc @@ -37,26 +37,23 @@ CelFunctionDescriptor kTwoInt("TwoInt", false, CelFunctionDescriptor kOneInt("OneInt", false, {CelValue::Type::kInt64}); -// Helper to confirm the set comparator works. -bool IsLessThan(const UnknownFunctionResult& lhs, - const UnknownFunctionResult& rhs) { - return UnknownFunctionComparator()(&lhs, &rhs); -} - TEST(UnknownFunctionResult, Equals) { UnknownFunctionResult call1(kTwoInt, /*expr_id=*/0); UnknownFunctionResult call2(kTwoInt, /*expr_id=*/0); EXPECT_TRUE(call1.IsEqualTo(call2)); - EXPECT_FALSE(IsLessThan(call1, call2)); - EXPECT_FALSE(IsLessThan(call2, call1)); UnknownFunctionResult call3(kOneInt, /*expr_id=*/0); UnknownFunctionResult call4(kOneInt, /*expr_id=*/0); EXPECT_TRUE(call3.IsEqualTo(call4)); + + UnknownFunctionResultSet call_set({call1, call3}); + EXPECT_EQ(call_set.size(), 2); + EXPECT_EQ(*call_set.begin(), call3); + EXPECT_EQ(*(++call_set.begin()), call1); } TEST(UnknownFunctionResult, InequalDescriptor) { @@ -65,7 +62,6 @@ TEST(UnknownFunctionResult, InequalDescriptor) { UnknownFunctionResult call2(kOneInt, /*expr_id=*/0); EXPECT_FALSE(call1.IsEqualTo(call2)); - EXPECT_TRUE(IsLessThan(call2, call1)); CelFunctionDescriptor one_uint("OneInt", false, {CelValue::Type::kUint64}); @@ -74,7 +70,13 @@ TEST(UnknownFunctionResult, InequalDescriptor) { UnknownFunctionResult call4(one_uint, /*expr_id=*/0); EXPECT_FALSE(call3.IsEqualTo(call4)); - EXPECT_TRUE(IsLessThan(call3, call4)); + + UnknownFunctionResultSet call_set({call1, call3, call4}); + EXPECT_EQ(call_set.size(), 3); + auto it = call_set.begin(); + EXPECT_EQ(*it++, call3); + EXPECT_EQ(*it++, call4); + EXPECT_EQ(*it++, call1); } } // namespace diff --git a/eval/public/unknown_set.h b/eval/public/unknown_set.h index 3b7168afe..50607bf5d 100644 --- a/eval/public/unknown_set.h +++ b/eval/public/unknown_set.h @@ -9,6 +9,8 @@ namespace api { namespace expr { namespace runtime { +class AttributeUtility; + // Class representing a collection of unknowns from a single evaluation pass of // a CEL expression. class UnknownSet { @@ -39,7 +41,22 @@ class UnknownSet { return unknown_function_results_; } + bool operator==(const UnknownSet& other) const { + return this == &other || + (unknown_attributes_ == other.unknown_attributes_ && + unknown_function_results_ == other.unknown_function_results_); + } + + bool operator!=(const UnknownSet& other) const { return !operator==(other); } + private: + friend class AttributeUtility; + + void Add(const UnknownSet& other) { + unknown_attributes_.Add(other.unknown_attributes_); + unknown_function_results_.Add(other.unknown_function_results_); + } + UnknownAttributeSet unknown_attributes_; UnknownFunctionResultSet unknown_function_results_; }; diff --git a/eval/public/unknown_set_test.cc b/eval/public/unknown_set_test.cc index 0a9cafdf6..1b67dfa38 100644 --- a/eval/public/unknown_set_test.cc +++ b/eval/public/unknown_set_test.cc @@ -19,9 +19,7 @@ using testing::UnorderedElementsAre; UnknownFunctionResultSet MakeFunctionResult(Arena* arena, int64_t id) { CelFunctionDescriptor desc("OneInt", false, {CelValue::Type::kInt64}); - const auto* function_result = - Arena::Create(arena, desc, /*expr_id=*/0); - return UnknownFunctionResultSet(function_result); + return UnknownFunctionResultSet(UnknownFunctionResult(desc, /*expr_id=*/0)); } UnknownAttributeSet MakeAttribute(Arena* arena, int64_t id) { @@ -31,16 +29,15 @@ UnknownAttributeSet MakeAttribute(Arena* arena, int64_t id) { std::vector attr_trail{ CelAttributeQualifier::Create(CelValue::CreateInt64(id))}; - const auto* attr = Arena::Create(arena, expr, attr_trail); - return UnknownAttributeSet({attr}); + return UnknownAttributeSet({CelAttribute(expr, std::move(attr_trail))}); } MATCHER_P(UnknownAttributeIs, id, "") { - const CelAttribute* attr = arg; - if (attr->qualifier_path().size() != 1) { + const CelAttribute& attr = arg; + if (attr.qualifier_path().size() != 1) { return false; } - auto maybe_qualifier = attr->qualifier_path()[0].GetInt64Key(); + auto maybe_qualifier = attr.qualifier_path()[0].GetInt64Key(); if (!maybe_qualifier.has_value()) { return false; } @@ -56,18 +53,17 @@ TEST(UnknownSet, AttributesMerge) { UnknownSet e(c, d); EXPECT_THAT( - d.unknown_attributes().attributes(), + d.unknown_attributes(), UnorderedElementsAre(UnknownAttributeIs(1), UnknownAttributeIs(2))); EXPECT_THAT( - e.unknown_attributes().attributes(), + e.unknown_attributes(), UnorderedElementsAre(UnknownAttributeIs(1), UnknownAttributeIs(2))); } TEST(UnknownSet, DefaultEmpty) { UnknownSet empty_set; - EXPECT_THAT(empty_set.unknown_attributes().attributes(), IsEmpty()); - EXPECT_THAT(empty_set.unknown_function_results().unknown_function_results(), - IsEmpty()); + EXPECT_THAT(empty_set.unknown_attributes(), IsEmpty()); + EXPECT_THAT(empty_set.unknown_function_results(), IsEmpty()); } TEST(UnknownSet, MixedMerges) { @@ -79,10 +75,10 @@ TEST(UnknownSet, MixedMerges) { UnknownSet d(a, b); UnknownSet e(c, d); - EXPECT_THAT(d.unknown_attributes().attributes(), + EXPECT_THAT(d.unknown_attributes(), UnorderedElementsAre(UnknownAttributeIs(1))); EXPECT_THAT( - e.unknown_attributes().attributes(), + e.unknown_attributes(), UnorderedElementsAre(UnknownAttributeIs(1), UnknownAttributeIs(2))); } diff --git a/eval/tests/unknowns_end_to_end_test.cc b/eval/tests/unknowns_end_to_end_test.cc index 5e4d969c2..aa809bec7 100644 --- a/eval/tests/unknowns_end_to_end_test.cc +++ b/eval/tests/unknowns_end_to_end_test.cc @@ -162,13 +162,13 @@ class UnknownsTest : public testing::Test { }; MATCHER_P(FunctionCallIs, fn_name, "") { - const UnknownFunctionResult* result = arg; - return result->descriptor().name() == fn_name; + const UnknownFunctionResult& result = arg; + return result.descriptor().name() == fn_name; } MATCHER_P(AttributeIs, attr, "") { - const CelAttribute* result = arg; - return result->variable_name() == attr; + const CelAttribute& result = arg; + return result.variable_name() == attr; } TEST_F(UnknownsTest, NoUnknowns) { @@ -211,7 +211,7 @@ TEST_F(UnknownsTest, UnknownAttributes) { CelValue response = maybe_response.value(); ASSERT_TRUE(response.IsUnknownSet()); - EXPECT_THAT(response.UnknownSetOrDie()->unknown_attributes().attributes(), + EXPECT_THAT(response.UnknownSetOrDie()->unknown_attributes(), ElementsAre(AttributeIs("var1"))); } @@ -275,9 +275,7 @@ TEST_F(UnknownsTest, UnknownFunctions) { CelValue response = maybe_response.value(); ASSERT_TRUE(response.IsUnknownSet()) << *response.ErrorOrDie(); - EXPECT_THAT(response.UnknownSetOrDie() - ->unknown_function_results() - .unknown_function_results(), + EXPECT_THAT(response.UnknownSetOrDie()->unknown_function_results(), ElementsAre(FunctionCallIs("F1"))); } @@ -300,11 +298,9 @@ TEST_F(UnknownsTest, UnknownsMerge) { CelValue response = maybe_response.value(); ASSERT_TRUE(response.IsUnknownSet()) << *response.ErrorOrDie(); - EXPECT_THAT(response.UnknownSetOrDie() - ->unknown_function_results() - .unknown_function_results(), + EXPECT_THAT(response.UnknownSetOrDie()->unknown_function_results(), ElementsAre(FunctionCallIs("F1"))); - EXPECT_THAT(response.UnknownSetOrDie()->unknown_attributes().attributes(), + EXPECT_THAT(response.UnknownSetOrDie()->unknown_attributes(), ElementsAre(AttributeIs("var2"))); } @@ -452,9 +448,7 @@ TEST_F(UnknownsCompTest, UnknownsMerge) { CelValue response = eval_status.value(); ASSERT_TRUE(response.IsUnknownSet()) << *response.ErrorOrDie(); - EXPECT_THAT(response.UnknownSetOrDie() - ->unknown_function_results() - .unknown_function_results(), + EXPECT_THAT(response.UnknownSetOrDie()->unknown_function_results(), testing::SizeIs(1)); } @@ -589,9 +583,7 @@ TEST_F(UnknownsCompCondTest, UnknownConditionReturned) { ASSERT_TRUE(response.IsUnknownSet()) << *response.ErrorOrDie(); // The comprehension ends on the first non-bool condition, so we only get one // call captured in the UnknownSet. - EXPECT_THAT(response.UnknownSetOrDie() - ->unknown_function_results() - .unknown_function_results(), + EXPECT_THAT(response.UnknownSetOrDie()->unknown_function_results(), testing::SizeIs(1)); } @@ -710,13 +702,12 @@ TEST(UnknownsIterAttrTest, IterAttributeTrail) { CelValue response = plan->Evaluate(activation, &arena).value(); ASSERT_TRUE(response.IsUnknownSet()) << CelValue::TypeName(response.type()); - ASSERT_EQ( - response.UnknownSetOrDie()->unknown_attributes().attributes().size(), 1); + ASSERT_EQ(response.UnknownSetOrDie()->unknown_attributes().size(), 1); // 'var[1]' is partially unknown when we make the function call so we treat it // as unknown. ASSERT_EQ(response.UnknownSetOrDie() ->unknown_attributes() - .attributes()[0] + .begin() ->qualifier_path() .size(), 1); @@ -761,7 +752,7 @@ TEST(UnknownsIterAttrTest, IterAttributeTrailMapKeyTypes) { CelValue response = plan->Evaluate(activation, &arena).value(); ASSERT_TRUE(response.IsUnknownSet()) << CelValue::TypeName(response.type()); - ASSERT_EQ(response.UnknownSetOrDie(), &unknown_set); + ASSERT_EQ(*response.UnknownSetOrDie(), unknown_set); } TEST(UnknownsIterAttrTest, IterAttributeTrailMapKeyTypesShortcutted) { @@ -901,13 +892,12 @@ TEST(UnknownsIterAttrTest, IterAttributeTrailMap) { CelValue response = plan->Evaluate(activation, &arena).value(); ASSERT_TRUE(response.IsUnknownSet()) << CelValue::TypeName(response.type()); - ASSERT_EQ( - response.UnknownSetOrDie()->unknown_attributes().attributes().size(), 1); + ASSERT_EQ(response.UnknownSetOrDie()->unknown_attributes().size(), 1); // 'var[1].key' is unknown when we make the Fn function call. // comprehension is: ((([] + false) + unk) + false) -> unk ASSERT_EQ(response.UnknownSetOrDie() ->unknown_attributes() - .attributes()[0] + .begin() ->qualifier_path() .size(), 2); @@ -1019,13 +1009,12 @@ TEST(UnknownsIterAttrTest, IterAttributeTrailFilterValues) { CelValue response = plan->Evaluate(activation, &arena).value(); ASSERT_TRUE(response.IsUnknownSet()) << CelValue::TypeName(response.type()); - ASSERT_EQ( - response.UnknownSetOrDie()->unknown_attributes().attributes().size(), 1); + ASSERT_EQ(response.UnknownSetOrDie()->unknown_attributes().size(), 1); // 'var[1].value_key' is unknown when we make the cons function call. // comprehension is: ((([] + [1]) + unk) + [1]) -> unk ASSERT_EQ(response.UnknownSetOrDie() ->unknown_attributes() - .attributes()[0] + .begin() ->qualifier_path() .size(), 2); @@ -1085,11 +1074,10 @@ TEST(UnknownsIterAttrTest, IterAttributeTrailFilterConditions) { // loop2: (true)? unk{1} + [1] : unk{1} -> unk{1} // result: unk{1} ASSERT_TRUE(response.IsUnknownSet()) << CelValue::TypeName(response.type()); - ASSERT_EQ( - response.UnknownSetOrDie()->unknown_attributes().attributes().size(), 1); + ASSERT_EQ(response.UnknownSetOrDie()->unknown_attributes().size(), 1); ASSERT_EQ(response.UnknownSetOrDie() ->unknown_attributes() - .attributes()[0] + .begin() ->qualifier_path() .size(), 2); From 88cdde244f32e11e73bc6ec1e2f713be7a7b32c5 Mon Sep 17 00:00:00 2001 From: jcking Date: Tue, 9 Aug 2022 19:24:42 +0000 Subject: [PATCH 023/303] Internal change PiperOrigin-RevId: 466433737 --- base/kind.h | 17 ++++++++++------- eval/public/BUILD | 1 + 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/base/kind.h b/base/kind.h index 66ecb2016..5f3f8d65a 100644 --- a/base/kind.h +++ b/base/kind.h @@ -22,25 +22,28 @@ namespace cel { -enum class Kind : uint8_t { +enum class Kind /* : uint8_t */ { + // Must match legacy CelValue::Type. kNullType = 0, - kError, - kDyn, - kAny, - kType, kBool, kInt, kUint, kDouble, kString, kBytes, - kEnum, + kStruct, kDuration, kTimestamp, kList, kMap, - kStruct, kUnknown, + kType, + kError, + kAny, + + // New kinds not present in legacy CelValue. + kEnum, + kDyn, // INTERNAL: Do not exceed 127. Implementation details rely on the fact that // we can store `Kind` using 7 bits. diff --git a/eval/public/BUILD b/eval/public/BUILD index de25a91d9..01795cc75 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -587,6 +587,7 @@ cc_test( ":cel_value_internal", ":unknown_attribute_set", ":unknown_set", + "//base:kind", "//base:memory_manager", "//eval/public/structs:legacy_type_info_apis", "//eval/public/structs:trivial_legacy_type_info", From 3fca4138afa285ac7e6a9a532d6909f7fc388980 Mon Sep 17 00:00:00 2001 From: jcking Date: Tue, 9 Aug 2022 22:09:50 +0000 Subject: [PATCH 024/303] Internal change PiperOrigin-RevId: 466481149 --- base/kind.h | 7 +++++++ eval/public/BUILD | 1 + eval/public/cel_value.h | 8 ++++++-- 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/base/kind.h b/base/kind.h index 5f3f8d65a..08690b89f 100644 --- a/base/kind.h +++ b/base/kind.h @@ -45,6 +45,13 @@ enum class Kind /* : uint8_t */ { kEnum, kDyn, + // Legacy aliases, deprecated do not use. + kInt64 = kInt, + kUint64 = kUint, + kMessage = kStruct, + kUnknownSet = kUnknown, + kCelType = kType, + // INTERNAL: Do not exceed 127. Implementation details rely on the fact that // we can store `Kind` using 7 bits. }; diff --git a/eval/public/BUILD b/eval/public/BUILD index 01795cc75..06893aa91 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -71,6 +71,7 @@ cc_library( deps = [ ":cel_value_internal", ":message_wrapper", + "//base:kind", "//base:memory_manager", "//eval/public/structs:legacy_type_info_apis", "//extensions/protobuf:memory_manager", diff --git a/eval/public/cel_value.h b/eval/public/cel_value.h index 173d59db7..a9c7c8eb9 100644 --- a/eval/public/cel_value.h +++ b/eval/public/cel_value.h @@ -32,6 +32,7 @@ #include "absl/time/time.h" #include "absl/types/optional.h" #include "absl/types/variant.h" +#include "base/kind.h" #include "base/memory_manager.h" #include "eval/public/cel_value_internal.h" #include "eval/public/message_wrapper.h" @@ -136,7 +137,10 @@ class CelValue { // Enum for types supported. // This is not recommended for use in exhaustive switches in client code. // Types may be updated over time. - enum class Type { + using Type = ::cel::Kind; + + // Legacy enumeration that is here for testing purposes. Do not use. + enum class LegacyType { kNullType = IndexOf::value, kBool = IndexOf::value, kInt64 = IndexOf::value, @@ -160,7 +164,7 @@ class CelValue { CelValue() : CelValue(NullType()) {} // Returns Type that describes the type of value stored. - Type type() const { return Type(value_.index()); } + Type type() const { return static_cast(value_.index()); } // Returns debug string describing a value const std::string DebugString() const; From 0915d16a453a4dfe40656ffdbe2b38672c152dff Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 10 Aug 2022 14:39:25 +0000 Subject: [PATCH 025/303] Add missing Bytes field from Constant, and provide missing setters. PiperOrigin-RevId: 466680535 --- base/BUILD | 3 +-- base/ast.h | 53 ++++++++++++++++++++++++++++++++++++-- base/ast_test.cc | 66 +++++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 117 insertions(+), 5 deletions(-) diff --git a/base/BUILD b/base/BUILD index c012c81c3..069c4137a 100644 --- a/base/BUILD +++ b/base/BUILD @@ -311,8 +311,7 @@ cc_test( deps = [ ":ast", "//internal:testing", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/types:variant", + "@com_google_absl//absl/time", ], ) diff --git a/base/ast.h b/base/ast.h index 4c8322f31..caac07da4 100644 --- a/base/ast.h +++ b/base/ast.h @@ -31,6 +31,13 @@ namespace cel::ast::internal { enum class NullValue { kNullValue = 0 }; +// A holder class to differentiate between CEL string and CEL bytes constants. +struct Bytes { + std::string bytes; + + bool operator==(const Bytes& other) const { return bytes == other.bytes; } +}; + // Represents a primitive literal. // // This is similar as the primitives supported in the well-known type @@ -48,8 +55,9 @@ enum class NullValue { kNullValue = 0 }; // message that can hold any constant object representation supplied or // produced at evaluation time. // --) -using ConstantKind = absl::variant; +using ConstantKind = + absl::variant; class Constant { public: @@ -77,6 +85,8 @@ class Constant { return NullValue::kNullValue; } + void set_null_value(NullValue null_value) { constant_kind_ = null_value; } + bool has_bool_value() const { return absl::holds_alternative(constant_kind_); } @@ -89,6 +99,8 @@ class Constant { return false; } + void set_bool_value(bool bool_value) { constant_kind_ = bool_value; } + bool has_int64_value() const { return absl::holds_alternative(constant_kind_); } @@ -101,6 +113,8 @@ class Constant { return 0; } + void set_int64_value(int64_t int64_value) { constant_kind_ = int64_value; } + bool has_uint64_value() const { return absl::holds_alternative(constant_kind_); } @@ -113,6 +127,10 @@ class Constant { return 0; } + void set_uint64_value(uint64_t uint64_value) { + constant_kind_ = uint64_value; + } + bool has_double_value() const { return absl::holds_alternative(constant_kind_); } @@ -125,6 +143,8 @@ class Constant { return 0; } + void set_double_value(double double_value) { constant_kind_ = double_value; } + bool has_string_value() const { return absl::holds_alternative(constant_kind_); } @@ -138,10 +158,31 @@ class Constant { return *default_string_value_; } + void set_string_value(std::string string_value) { + constant_kind_ = string_value; + } + + const std::string& bytes_value() const { + auto* value = absl::get_if(&constant_kind_); + if (value != nullptr) { + return value->bytes; + } + static std::string* default_string_value_ = new std::string(""); + return *default_string_value_; + } + + void set_bytes_value(std::string bytes_value) { + constant_kind_ = Bytes{std::move(bytes_value)}; + } + bool has_duration_value() const { return absl::holds_alternative(constant_kind_); } + void set_duration_value(absl::Duration duration_value) { + constant_kind_ = std::move(duration_value); + } + const absl::Duration& duration_value() const { auto* value = absl::get_if(&constant_kind_); if (value != nullptr) { @@ -164,6 +205,10 @@ class Constant { return default_time_; } + void set_time_value(absl::Time time_value) { + constant_kind_ = std::move(time_value); + } + bool operator==(const Constant& other) const { return constant_kind_ == other.constant_kind_; } @@ -369,6 +414,10 @@ class CreateStruct { return *default_field_key; } + void set_field_key(std::string field_key) { + key_kind_ = std::move(field_key); + } + const Expr& map_key() const; Expr& mutable_map_key() { diff --git a/base/ast_test.cc b/base/ast_test.cc index a2d188188..a1d1722af 100644 --- a/base/ast_test.cc +++ b/base/ast_test.cc @@ -17,7 +17,7 @@ #include #include -#include "absl/memory/memory.h" +#include "absl/time/time.h" #include "internal/testing.h" namespace cel { @@ -32,6 +32,63 @@ TEST(AstTest, ExprConstructionConstant) { ASSERT_TRUE(constant.bool_value()); } +TEST(AstTest, ConstantNullValueSetterGetterTest) { + Constant constant; + constant.set_null_value(NullValue::kNullValue); + EXPECT_EQ(constant.null_value(), NullValue::kNullValue); +} + +TEST(AstTest, ConstantBoolValueSetterGetterTest) { + Constant constant; + constant.set_bool_value(true); + EXPECT_TRUE(constant.bool_value()); + constant.set_bool_value(false); + EXPECT_FALSE(constant.bool_value()); +} + +TEST(AstTest, ConstantInt64ValueSetterGetterTest) { + Constant constant; + constant.set_int64_value(-1234); + EXPECT_EQ(constant.int64_value(), -1234); +} + +TEST(AstTest, ConstantUint64ValueSetterGetterTest) { + Constant constant; + constant.set_uint64_value(1234); + EXPECT_EQ(constant.uint64_value(), 1234); +} + +TEST(AstTest, ConstantDoubleValueSetterGetterTest) { + Constant constant; + constant.set_double_value(12.34); + EXPECT_EQ(constant.double_value(), 12.34); +} + +TEST(AstTest, ConstantStringValueSetterGetterTest) { + Constant constant; + constant.set_string_value("test"); + EXPECT_EQ(constant.string_value(), "test"); +} + +TEST(AstTest, ConstantBytesValueSetterGetterTest) { + Constant constant; + constant.set_string_value("test"); + EXPECT_EQ(constant.string_value(), "test"); +} + +TEST(AstTest, ConstantDurationValueSetterGetterTest) { + Constant constant; + constant.set_duration_value(absl::Seconds(10)); + EXPECT_EQ(constant.duration_value(), absl::Seconds(10)); +} + +TEST(AstTest, ConstantTimeValueSetterGetterTest) { + Constant constant; + auto time = absl::UnixEpoch() + absl::Seconds(10); + constant.set_time_value(time); + EXPECT_EQ(constant.time_value(), time); +} + TEST(AstTest, ConstantDefaults) { Constant constant; EXPECT_EQ(constant.null_value(), NullValue::kNullValue); @@ -40,6 +97,7 @@ TEST(AstTest, ConstantDefaults) { EXPECT_EQ(constant.uint64_value(), 0); EXPECT_EQ(constant.double_value(), 0); EXPECT_TRUE(constant.string_value().empty()); + EXPECT_TRUE(constant.bytes_value().empty()); EXPECT_EQ(constant.duration_value(), absl::Duration()); EXPECT_EQ(constant.time_value(), absl::UnixEpoch()); } @@ -182,6 +240,12 @@ TEST(AstTest, CreateStructEntryMutableMapKey) { ASSERT_EQ(absl::get(entry.map_key().expr_kind()).name(), "new_key"); } +TEST(AstTest, CreateStructEntryFieldKeyGetterSetterTest) { + CreateStruct::Entry entry; + entry.set_field_key("key"); + EXPECT_EQ(entry.field_key(), "key"); +} + TEST(AstTest, ExprConstructionComprehension) { Comprehension comprehension; comprehension.set_iter_var("iter_var"); From 0039f4a2ee3767fb45c841a273faa15643b383cb Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 11 Aug 2022 02:00:31 +0000 Subject: [PATCH 026/303] Fix CreateStruct::Entry comparator to compare the value of the key instead of the pointer. PiperOrigin-RevId: 466839120 --- base/ast.cc | 9 +++++-- base/ast.h | 4 ++- base/ast_test.cc | 64 ++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 74 insertions(+), 3 deletions(-) diff --git a/base/ast.cc b/base/ast.cc index 7ae327bdd..1221e50bc 100644 --- a/base/ast.cc +++ b/base/ast.cc @@ -65,8 +65,13 @@ const Expr& CreateStruct::Entry::value() const { } bool CreateStruct::Entry::operator==(const Entry& other) const { - return id_ == other.id_ && key_kind_ == other.key_kind_ && - value() == other.value(); + bool has_same_key = false; + if (has_field_key() && other.has_field_key()) { + has_same_key = field_key() == other.field_key(); + } else if (has_map_key() && other.has_map_key()) { + has_same_key = map_key() == other.map_key(); + } + return id_ == other.id_ && has_same_key && value() == other.value(); } const Expr& Comprehension::iter_range() const { diff --git a/base/ast.h b/base/ast.h index caac07da4..102fd8ced 100644 --- a/base/ast.h +++ b/base/ast.h @@ -442,11 +442,13 @@ class CreateStruct { bool operator==(const Entry& other) const; + bool operator!=(const Entry& other) const { return !operator==(other); } + private: // Required. An id assigned to this node by the parser which is unique // in a given expression tree. This is used to associate type // information and other attributes to the node. - int64_t id_; + int64_t id_ = 0; // The `Entry` key kinds. KeyKind key_kind_; // Required. The value assigned to the key. diff --git a/base/ast_test.cc b/base/ast_test.cc index a1d1722af..e56ae315c 100644 --- a/base/ast_test.cc +++ b/base/ast_test.cc @@ -246,6 +246,70 @@ TEST(AstTest, CreateStructEntryFieldKeyGetterSetterTest) { EXPECT_EQ(entry.field_key(), "key"); } +TEST(AstTest, CreateStructEntryComparatorMapKeySuccess) { + CreateStruct::Entry entry1; + entry1.mutable_map_key().set_expr_kind(Ident("key")); + CreateStruct::Entry entry2; + entry2.mutable_map_key().set_expr_kind(Ident("key")); + EXPECT_EQ(entry1, entry2); +} + +TEST(AstTest, CreateStructEntryComparatorMapKeyFailure) { + CreateStruct::Entry entry1; + entry1.mutable_map_key().set_expr_kind(Ident("key")); + CreateStruct::Entry entry2; + entry2.mutable_map_key().set_expr_kind(Ident("other_key")); + EXPECT_NE(entry1, entry2); +} + +TEST(AstTest, CreateStructEntryComparatorFieldKeySuccess) { + CreateStruct::Entry entry1; + entry1.set_field_key("key"); + CreateStruct::Entry entry2; + entry2.set_field_key("key"); + EXPECT_EQ(entry1, entry2); +} + +TEST(AstTest, CreateStructEntryComparatorFieldKeyFailure) { + CreateStruct::Entry entry1; + entry1.set_field_key("key"); + CreateStruct::Entry entry2; + entry2.set_field_key("other_key"); + EXPECT_NE(entry1, entry2); +} + +TEST(AstTest, CreateStructEntryComparatorFieldKeyDiffersFromMapKey) { + CreateStruct::Entry entry1; + entry1.set_field_key(""); + CreateStruct::Entry entry2; + entry2.mutable_map_key(); + EXPECT_NE(entry1, entry2); +} + +TEST(AstTest, CreateStructEntryComparatorMapKeyDiffersFromFieldKey) { + CreateStruct::Entry entry1; + entry1.mutable_map_key(); + CreateStruct::Entry entry2; + entry2.set_field_key(""); + EXPECT_NE(entry1, entry2); +} + +TEST(AstTest, CreateStructEntryComparatorValueSuccess) { + CreateStruct::Entry entry1; + entry1.mutable_value().set_expr_kind(Ident("key")); + CreateStruct::Entry entry2; + entry2.mutable_value().set_expr_kind(Ident("key")); + EXPECT_EQ(entry1, entry2); +} + +TEST(AstTest, CreateStructEntryComparatorValueFailure) { + CreateStruct::Entry entry1; + entry1.mutable_value().set_expr_kind(Ident("key")); + CreateStruct::Entry entry2; + entry2.mutable_value().set_expr_kind(Ident("other_key")); + EXPECT_NE(entry1, entry2); +} + TEST(AstTest, ExprConstructionComprehension) { Comprehension comprehension; comprehension.set_iter_var("iter_var"); From cd4245907c42fdd5370d88988a0bc8032a84a413 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 11 Aug 2022 02:28:03 +0000 Subject: [PATCH 027/303] Convert FlatExprBuilder internals to use native types PiperOrigin-RevId: 466843298 --- base/BUILD | 5 +- base/ast.cc | 4 + base/ast.h | 33 +- base/ast_test.cc | 5 + base/ast_utility.cc | 134 ++-- base/ast_utility.h | 2 + base/ast_utility_test.cc | 19 +- eval/compiler/BUILD | 28 +- eval/compiler/constant_folding.cc | 218 ++++--- eval/compiler/constant_folding.h | 18 +- eval/compiler/constant_folding_test.cc | 88 +-- eval/compiler/flat_expr_builder.cc | 463 ++++++++------ eval/compiler/flat_expr_builder_test.cc | 36 +- eval/compiler/qualified_reference_resolver.cc | 114 ++-- eval/compiler/qualified_reference_resolver.h | 10 +- .../qualified_reference_resolver_test.cc | 491 +++++++------- eval/eval/BUILD | 7 + eval/eval/comprehension_step_test.cc | 29 +- eval/eval/const_value_step.cc | 71 +- eval/eval/const_value_step.h | 3 +- eval/eval/const_value_step_test.cc | 71 +- eval/eval/container_access_step.cc | 4 +- eval/eval/container_access_step.h | 2 +- eval/eval/container_access_step_test.cc | 58 +- eval/eval/create_list_step.cc | 10 +- eval/eval/create_list_step.h | 6 +- eval/eval/create_list_step_test.cc | 41 +- eval/eval/create_struct_step.cc | 8 +- eval/eval/create_struct_step.h | 4 +- eval/eval/create_struct_step_test.cc | 58 +- eval/eval/evaluator_core.h | 27 +- eval/eval/evaluator_core_test.cc | 2 +- eval/eval/function_step.cc | 18 +- eval/eval/function_step.h | 4 +- eval/eval/function_step_test.cc | 254 ++++---- eval/eval/ident_step.cc | 4 +- eval/eval/ident_step.h | 3 +- eval/eval/ident_step_test.cc | 35 +- eval/eval/logic_step_test.cc | 35 +- eval/eval/select_step.cc | 4 +- eval/eval/select_step.h | 3 +- eval/eval/select_step_test.cc | 96 +-- eval/eval/shadowable_value_step_test.cc | 2 +- eval/eval/ternary_step_test.cc | 32 +- eval/public/BUILD | 34 + eval/public/ast_rewrite_native.cc | 404 ++++++++++++ eval/public/ast_rewrite_native.h | 155 +++++ eval/public/ast_rewrite_native_test.cc | 604 ++++++++++++++++++ eval/public/ast_traverse_native.cc | 5 + eval/tests/benchmark_test.cc | 14 +- 50 files changed, 2624 insertions(+), 1151 deletions(-) create mode 100644 eval/public/ast_rewrite_native.cc create mode 100644 eval/public/ast_rewrite_native.h create mode 100644 eval/public/ast_rewrite_native_test.cc diff --git a/base/BUILD b/base/BUILD index 069c4137a..458af29bc 100644 --- a/base/BUILD +++ b/base/BUILD @@ -296,8 +296,10 @@ cc_library( "ast.h", ], deps = [ - "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/time", "@com_google_absl//absl/types:variant", ], @@ -326,6 +328,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/time", + "@com_google_absl//absl/types:variant", "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", "@com_google_protobuf//:protobuf", diff --git a/base/ast.cc b/base/ast.cc index 1221e50bc..3c71395d7 100644 --- a/base/ast.cc +++ b/base/ast.cc @@ -14,7 +14,11 @@ #include "base/ast.h" +#include #include +#include +#include +#include namespace cel::ast::internal { diff --git a/base/ast.h b/base/ast.h index 102fd8ced..46b782fe9 100644 --- a/base/ast.h +++ b/base/ast.h @@ -162,6 +162,10 @@ class Constant { constant_kind_ = string_value; } + bool has_bytes_value() const { + return absl::holds_alternative(constant_kind_); + } + const std::string& bytes_value() const { auto* value = absl::get_if(&constant_kind_); if (value != nullptr) { @@ -658,8 +662,12 @@ class Comprehension { std::unique_ptr result_; }; -using ExprKind = absl::variant; +// Even though, the Expr proto does not allow for an unset, macro calls in the +// way they are used today sometimes elide parts of the AST if its +// unchanged/uninteresting. +using ExprKind = + absl::variant; // Analogous to google::api::expr::v1alpha1::Expr // An abstract representation of a common expression. @@ -1439,6 +1447,8 @@ class Type { // Describes a resolved reference to a declaration. class Reference { public: + Reference() {} + Reference(std::string name, std::vector overload_id, Constant value) : name_(std::move(name)), @@ -1457,11 +1467,24 @@ class Reference { const std::vector& overload_id() const { return overload_id_; } - const Constant& value() const { return value_; } + const Constant& value() const { + if (value_.has_value()) { + return value_.value(); + } + static const Constant* default_constant = new Constant; + return *default_constant; + } std::vector& mutable_overload_id() { return overload_id_; } - Constant& mutable_value() { return value_; } + Constant& mutable_value() { + if (!value_.has_value()) { + value_.emplace(); + } + return *value_; + } + + bool has_value() const { return value_.has_value(); } private: // The fully qualified name of the declaration. @@ -1477,7 +1500,7 @@ class Reference { std::vector overload_id_; // For references to constants, this may contain the value of the // constant if known at compile time. - Constant value_; + absl::optional value_; }; // Analogous to google::api::expr::v1alpha1::CheckedExpr diff --git a/base/ast_test.cc b/base/ast_test.cc index e56ae315c..d8db3f71a 100644 --- a/base/ast_test.cc +++ b/base/ast_test.cc @@ -553,6 +553,11 @@ TEST(AstTest, ExprMutableConstruction) { EXPECT_EQ(expr.comprehension_expr().accu_var(), "accu_var"); } +TEST(AstTest, ReferenceConstantDefaultValue) { + Reference reference; + EXPECT_EQ(reference.value(), Constant()); +} + } // namespace } // namespace internal } // namespace ast diff --git a/base/ast_utility.cc b/base/ast_utility.cc index a4cd54691..bb2673122 100644 --- a/base/ast_utility.cc +++ b/base/ast_utility.cc @@ -29,6 +29,7 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/time/time.h" +#include "absl/types/variant.h" #include "base/ast.h" namespace cel::ast::internal { @@ -48,7 +49,7 @@ absl::StatusOr ToNative(const google::api::expr::v1alpha1::Constant& c case google::api::expr::v1alpha1::Constant::kStringValue: return Constant(constant.string_value()); case google::api::expr::v1alpha1::Constant::kBytesValue: - return Constant(constant.bytes_value()); + return Constant(Bytes{constant.bytes_value()}); case google::api::expr::v1alpha1::Constant::kDurationValue: return Constant(absl::Seconds(constant.duration_value().seconds()) + absl::Nanoseconds(constant.duration_value().nanos())); @@ -57,8 +58,7 @@ absl::StatusOr ToNative(const google::api::expr::v1alpha1::Constant& c absl::FromUnixSeconds(constant.timestamp_value().seconds()) + absl::Nanoseconds(constant.timestamp_value().nanos())); default: - return absl::InvalidArgumentError( - "Illegal type supplied for google::api::expr::v1alpha1::Constant."); + return absl::InvalidArgumentError("Unsupported constant type"); } } @@ -76,21 +76,24 @@ absl::StatusOr ToNative(const google::api::expr::v1alpha1::Expr::Select& select) { - auto native_operand = ToNative(select.operand()); - if (!native_operand.ok()) { - return native_operand.status(); - } - return Select(std::make_unique(*std::move(native_operand)), - select.field(), select.test_only()); +absl::StatusOr ConvertSelect(const exprpb::Expr::Select& select, - std::stack& stack) { +absl::StatusOr" + line_offsets: 89 + positions: { + key: 1 + value: 3 + } + positions: { + key: 2 + value: 4 + } + positions: { + key: 3 + value: 7 + } + positions: { + key: 4 + value: 3 + } + positions: { + key: 5 + value: 25 + } + positions: { + key: 6 + value: 26 + } + positions: { + key: 7 + value: 29 + } + positions: { + key: 8 + value: 25 + } + positions: { + key: 9 + value: 19 + } + positions: { + key: 10 + value: 44 + } + positions: { + key: 11 + value: 54 + } + positions: { + key: 12 + value: 55 + } + positions: { + key: 13 + value: 58 + } + positions: { + key: 14 + value: 70 + } + positions: { + key: 15 + value: 73 + } + positions: { + key: 16 + value: 54 + } + positions: { + key: 17 + value: 85 + } + positions: { + key: 18 + value: 87 + } + positions: { + key: 19 + value: 41 + } + macro_calls: { + key: 4 + value: { + call_expr: { + function: "has" + args: { + id: 3 + select_expr: { + operand: { + id: 2 + ident_expr: { + name: "msg" + } + } + field: "old_field" + } + } + } + } + } + macro_calls: { + key: 8 + value: { + call_expr: { + function: "has" + args: { + id: 7 + select_expr: { + operand: { + id: 6 + ident_expr: { + name: "msg" + } + } + field: "old_field" + } + } + } + } + } + macro_calls: { + key: 16 + value: { + call_expr: { + target: { + id: 10 + ident_expr: { + name: "math" + } + } + function: "least" + args: { + id: 13 + select_expr: { + operand: { + id: 12 + ident_expr: { + name: "msg" + } + } + field: "old_field" + } + } + args: { + id: 15 + select_expr: { + operand: { + id: 14 + ident_expr: { + name: "msg" + } + } + field: "old_field" + } + } + } + } + } +} +expr: { + id: 19 + call_expr: { + function: "_||_" + args: { + id: 9 + call_expr: { + function: "_||_" + args: { + id: 4 + select_expr: { + operand: { + id: 2 + ident_expr: { + name: "msg" + } + } + field: "old_field" + test_only: true + } + } + args: { + id: 8 + select_expr: { + operand: { + id: 6 + ident_expr: { + name: "msg" + } + } + field: "old_field" + test_only: true + } + } + } + } + args: { + id: 17 + call_expr: { + function: "_<_" + args: { + id: 16 + call_expr: { + function: "math.@min" + args: { + id: 13 + select_expr: { + operand: { + id: 12 + ident_expr: { + name: "msg" + } + } + field: "old_field" + } + } + args: { + id: 15 + select_expr: { + operand: { + id: 14 + ident_expr: { + name: "msg" + } + } + field: "old_field" + } + } + } + } + args: { + id: 18 + const_expr: { + int64_value: 0 + } + } + } + } + } +} diff --git a/tools/testdata/macro_single_reference.textproto b/tools/testdata/macro_single_reference.textproto new file mode 100644 index 000000000..f34c21ad9 --- /dev/null +++ b/tools/testdata/macro_single_reference.textproto @@ -0,0 +1,81 @@ +# proto-file: google3/google/api/expr/checked.proto +# proto-message: CheckedExpr +# has(msg.old_field) +reference_map: { + key: 2 + value: { + name: "msg" + } +} +type_map: { + key: 2 + value: { + map_type: { + key_type: { + primitive: STRING + } + value_type: { + primitive: STRING + } + } + } +} +type_map: { + key: 4 + value: { + primitive: BOOL + } +} +source_info: { + location: "" + line_offsets: 15 + positions: { + key: 1 + value: 3 + } + positions: { + key: 2 + value: 4 + } + positions: { + key: 3 + value: 7 + } + positions: { + key: 4 + value: 3 + } + macro_calls: { + key: 4 + value: { + call_expr: { + function: "has" + args: { + id: 3 + select_expr: { + operand: { + id: 2 + ident_expr: { + name: "msg" + } + } + field: "old_field" + } + } + } + } + } +} +expr: { + id: 4 + select_expr: { + operand: { + id: 2 + ident_expr: { + name: "msg" + } + } + field: "old_field" + test_only: true + } +} diff --git a/tools/testdata/msg_new_field.textproto b/tools/testdata/msg_new_field.textproto new file mode 100644 index 000000000..3676d03a0 --- /dev/null +++ b/tools/testdata/msg_new_field.textproto @@ -0,0 +1,52 @@ +# proto-file: google3/google/api/expr/checked.proto +# proto-message: CheckedExpr +# msg.new_field +reference_map: { + key: 1 + value: { + name: "msg" + } +} +type_map: { + key: 1 + value: { + map_type: { + key_type: { + primitive: STRING + } + value_type: { + primitive: STRING + } + } + } +} +type_map: { + key: 2 + value: { + primitive: STRING + } +} +source_info: { + location: "" + line_offsets: 10 + positions: { + key: 1 + value: 0 + } + positions: { + key: 2 + value: 3 + } +} +expr: { + id: 2 + select_expr: { + operand: { + id: 1 + ident_expr: { + name: "msg" + } + } + field: "new_field" + } +} diff --git a/tools/testdata/msg_new_field_int.textproto b/tools/testdata/msg_new_field_int.textproto new file mode 100644 index 000000000..c7fd9bb43 --- /dev/null +++ b/tools/testdata/msg_new_field_int.textproto @@ -0,0 +1,52 @@ +# proto-file: google3/google/api/expr/checked.proto +# proto-message: CheckedExpr +# msg.new_field +reference_map: { + key: 1 + value: { + name: "msg" + } +} +type_map: { + key: 1 + value: { + map_type: { + key_type: { + primitive: STRING + } + value_type: { + primitive: INT64 + } + } + } +} +type_map: { + key: 2 + value: { + primitive: INT64 + } +} +source_info: { + location: "" + line_offsets: 14 + positions: { + key: 1 + value: 0 + } + positions: { + key: 2 + value: 3 + } +} +expr: { + id: 2 + select_expr: { + operand: { + id: 1 + ident_expr: { + name: "msg" + } + } + field: "new_field" + } +} From 66c5e224cd4593a27a2500cd0043d685d56550e1 Mon Sep 17 00:00:00 2001 From: jdtatum Date: Fri, 14 Apr 2023 21:35:46 +0000 Subject: [PATCH 215/303] Update flat expr builder to use cel::RuntimeOptions directly instead of copying individual feature flags. PiperOrigin-RevId: 524388893 --- eval/compiler/BUILD | 2 + eval/compiler/flat_expr_builder.cc | 58 +++--- eval/compiler/flat_expr_builder.h | 56 ------ .../flat_expr_builder_comprehensions_test.cc | 2 - ...ilder_short_circuiting_conformance_test.cc | 10 +- eval/compiler/flat_expr_builder_test.cc | 168 +++++++++++------- eval/eval/evaluator_core_test.cc | 1 - .../portable_cel_expr_builder_factory.cc | 11 -- 8 files changed, 138 insertions(+), 170 deletions(-) diff --git a/eval/compiler/BUILD b/eval/compiler/BUILD index 634d4f811..af9be3086 100644 --- a/eval/compiler/BUILD +++ b/eval/compiler/BUILD @@ -93,6 +93,7 @@ cc_test( "//internal:status_macros", "//internal:testing", "//parser", + "//runtime:runtime_options", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -262,6 +263,7 @@ cc_test( "//eval/public:unknown_set", "//internal:status_macros", "//internal:testing", + "//runtime:runtime_options", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_protobuf//:protobuf", diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index aa58c3f0b..8dca451da 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -271,15 +271,13 @@ class FlatExprVisitor : public cel::ast::internal::AstVisitor { FlatExprVisitor( const google::api::expr::runtime::Resolver& resolver, google::api::expr::runtime::ExecutionPath* path, - const cel::RuntimeOptions& options, bool short_circuiting, + const cel::RuntimeOptions& options, const absl::flat_hash_map>& constant_idents, - google::protobuf::Arena* constant_arena, bool enable_comprehension, - bool enable_comprehension_list_append, + google::protobuf::Arena* constant_arena, bool enable_comprehension_vulnerability_check, - bool enable_wrapper_type_null_unboxing, google::api::expr::runtime::BuilderWarnings* warnings, - std::set* iter_variable_names, bool enable_regex, - bool enable_regex_precompilation, int regex_max_program_size, + std::set* iter_variable_names, + bool enable_regex_precompilation, const absl::flat_hash_map* reference_map, google::protobuf::Arena* arena) @@ -288,23 +286,17 @@ class FlatExprVisitor : public cel::ast::internal::AstVisitor { progress_status_(absl::OkStatus()), resolved_select_expr_(nullptr), options_(options), - short_circuiting_(short_circuiting), constant_idents_(constant_idents), constant_arena_(constant_arena), - enable_comprehension_(enable_comprehension), - enable_comprehension_list_append_(enable_comprehension_list_append), enable_comprehension_vulnerability_check_( enable_comprehension_vulnerability_check), - enable_wrapper_type_null_unboxing_(enable_wrapper_type_null_unboxing), builder_warnings_(warnings), iter_variable_names_(iter_variable_names), - enable_regex_(enable_regex), enable_regex_precompilation_(enable_regex_precompilation), - regex_program_builder_(regex_max_program_size), + regex_program_builder_(options_.regex_max_program_size), reference_map_(reference_map), arena_(arena) { DCHECK(iter_variable_names_); - static_cast(options_); // TODO(issues/5): follow-up will use this. } void PreVisitExpr(const cel::ast::internal::Expr* expr, @@ -453,7 +445,7 @@ class FlatExprVisitor : public cel::ast::internal::AstVisitor { } AddStep(CreateSelectStep(*select_expr, expr->id(), select_path, - enable_wrapper_type_null_unboxing_)); + options_.enable_empty_wrapper_null_unboxing)); } // Call node handler group. @@ -470,14 +462,14 @@ class FlatExprVisitor : public cel::ast::internal::AstVisitor { std::unique_ptr cond_visitor; if (call_expr->function() == google::api::expr::runtime::builtin::kAnd) { cond_visitor = std::make_unique( - this, /* cond_value= */ false, short_circuiting_); + this, /* cond_value= */ false, options_.short_circuiting); } else if (call_expr->function() == google::api::expr::runtime::builtin::kOr) { cond_visitor = std::make_unique( - this, /* cond_value= */ true, short_circuiting_); + this, /* cond_value= */ true, options_.short_circuiting); } else if (call_expr->function() == google::api::expr::runtime::builtin::kTernary) { - if (short_circuiting_) { + if (options_.short_circuiting) { cond_visitor = std::make_unique(this); } else { cond_visitor = std::make_unique(this); @@ -521,7 +513,7 @@ class FlatExprVisitor : public cel::ast::internal::AstVisitor { // Check to see if this is regular expression matching and the pattern is a // constant. - if (enable_regex_ && enable_regex_precompilation_ && + if (options_.enable_regex && enable_regex_precompilation_ && IsOptimizeableMatchesCall(*expr, *call_expr)) { auto program = regex_program_builder_.BuildRegexProgram( GetConstantString(call_expr->args().back())); @@ -535,7 +527,7 @@ class FlatExprVisitor : public cel::ast::internal::AstVisitor { // Check to see if this is a special case of add that should really be // treated as a list append - if (enable_comprehension_list_append_ && + if (options_.enable_comprehension_list_append && call_expr->function() == google::api::expr::runtime::builtin::kAdd && call_expr->args().size() == 2 && !comprehension_stack_.empty()) { const cel::ast::internal::Comprehension* comprehension = @@ -597,7 +589,7 @@ class FlatExprVisitor : public cel::ast::internal::AstVisitor { if (!progress_status_.ok()) { return; } - if (!ValidateOrError(enable_comprehension_, + if (!ValidateOrError(options_.enable_comprehension, "Comprehension support is disabled")) { return; } @@ -621,7 +613,7 @@ class FlatExprVisitor : public cel::ast::internal::AstVisitor { comprehension_stack_.push(comprehension); cond_visitor_stack_.push( {expr, std::make_unique( - this, short_circuiting_, + this, options_.short_circuiting, enable_comprehension_vulnerability_check_)}); auto cond_visitor = FindCondVisitor(expr); cond_visitor->PreVisit(expr); @@ -675,7 +667,8 @@ class FlatExprVisitor : public cel::ast::internal::AstVisitor { if (!progress_status_.ok()) { return; } - if (enable_comprehension_list_append_ && !comprehension_stack_.empty() && + if (options_.enable_comprehension_list_append && + !comprehension_stack_.empty() && &(comprehension_stack_.top()->accu_init()) == expr) { AddStep(CreateCreateMutableListStep(*list_expr, expr->id())); return; @@ -827,23 +820,17 @@ class FlatExprVisitor : public cel::ast::internal::AstVisitor { const cel::ast::internal::Expr* resolved_select_expr_; const cel::RuntimeOptions& options_; - bool short_circuiting_; const absl::flat_hash_map>& constant_idents_; google::protobuf::Arena* constant_arena_; - bool enable_comprehension_; - bool enable_comprehension_list_append_; std::stack comprehension_stack_; bool enable_comprehension_vulnerability_check_; - bool enable_wrapper_type_null_unboxing_; - google::api::expr::runtime::BuilderWarnings* builder_warnings_; std::set* iter_variable_names_; - bool enable_regex_; bool enable_regex_precompilation_; RegexProgramBuilder regex_program_builder_; const absl::flat_hash_map* const @@ -1273,9 +1260,10 @@ absl::StatusOr> FlatExprBuilder::CreateExpressionImpl( cel::ast::Ast& ast, std::vector* warnings) const { ExecutionPath execution_path; - BuilderWarnings warnings_builder(fail_on_warnings_); + BuilderWarnings warnings_builder(options_.fail_on_warnings); Resolver resolver(container(), GetRegistry()->InternalGetRegistry(), - GetTypeRegistry(), enable_qualified_type_identifiers_); + GetTypeRegistry(), + options_.enable_qualified_type_identifiers); absl::flat_hash_map> constant_idents; auto& ast_impl = AstImpl::CastFromPublicAst(ast); const cel::ast::internal::Expr* effective_expr = &ast_impl.root_expr(); @@ -1315,12 +1303,10 @@ FlatExprBuilder::CreateExpressionImpl( std::set iter_variable_names; FlatExprVisitor visitor( - resolver, &execution_path, options_, shortcircuiting_, constant_idents, - constant_arena_, enable_comprehension_, enable_comprehension_list_append_, - enable_comprehension_vulnerability_check_, - enable_wrapper_type_null_unboxing_, &warnings_builder, - &iter_variable_names, enable_regex_, enable_regex_precompilation_, - regex_max_program_size_, &ast_impl.reference_map(), arena.get()); + resolver, &execution_path, options_, constant_idents, constant_arena_, + enable_comprehension_vulnerability_check_, &warnings_builder, + &iter_variable_names, enable_regex_precompilation_, + &ast_impl.reference_map(), arena.get()); AstTraverse(effective_expr, &ast_impl.source_info(), &visitor); diff --git a/eval/compiler/flat_expr_builder.h b/eval/compiler/flat_expr_builder.h index 2b0113617..89c745e7f 100644 --- a/eval/compiler/flat_expr_builder.h +++ b/eval/compiler/flat_expr_builder.h @@ -55,10 +55,6 @@ class FlatExprBuilder : public CelExpressionBuilder { enable_unknown_function_results_ = enabled; } - // set_shortcircuiting regulates shortcircuiting of some expressions. - // Be default shortcircuiting is enabled. - void set_shortcircuiting(bool enabled) { shortcircuiting_ = enabled; } - // Toggle constant folding optimization. By default it is not enabled. // The provided arena is used to hold the generated constants. void set_constant_folding(bool enabled, google::protobuf::Arena* arena) { @@ -66,40 +62,10 @@ class FlatExprBuilder : public CelExpressionBuilder { constant_arena_ = arena; } - void set_enable_comprehension(bool enabled) { - enable_comprehension_ = enabled; - } - void set_comprehension_max_iterations(int max_iterations) { comprehension_max_iterations_ = max_iterations; } - // Warnings (e.g. no function bound) fail immediately. - void set_fail_on_warnings(bool should_fail) { - fail_on_warnings_ = should_fail; - } - - // set_enable_qualified_type_identifiers controls whether select expressions - // may be treated as constant type identifiers during CelExpression creation. - void set_enable_qualified_type_identifiers(bool enabled) { - enable_qualified_type_identifiers_ = enabled; - } - - // set_enable_comprehension_list_append controls whether the FlatExprBuilder - // will attempt to optimize list concatenation within map() and filter() - // macro comprehensions as an append of results on the `accu_var` rather than - // as a reassignment of the `accu_var` to the concatenation of - // `accu_var` + [elem]. - // - // Before enabling, ensure that `#list_append` is not a function declared - // within your runtime, and that your CEL expressions retain their integer - // identifiers. - // - // This option is not safe for use with hand-rolled ASTs. - void set_enable_comprehension_list_append(bool enabled) { - enable_comprehension_list_append_ = enabled; - } - // set_enable_comprehension_vulnerability_check inspects comprehension // sub-expressions for the presence of potential memory exhaustion. // @@ -112,14 +78,6 @@ class FlatExprBuilder : public CelExpressionBuilder { enable_comprehension_vulnerability_check_ = enabled; } - // If set_enable_wrapper_type_null_unboxing is enabled, the evaluator will - // return null for well known wrapper type fields if they are unset. - // The default is disabled and follows protobuf behavior (returning the - // proto default for the wrapped type). - void set_enable_wrapper_type_null_unboxing(bool enabled) { - enable_wrapper_type_null_unboxing_ = enabled; - } - // If enable_heterogeneous_equality is enabled, the evaluator will use // hetergeneous equality semantics. This includes the == operator and numeric // index lookups in containers. @@ -138,16 +96,10 @@ class FlatExprBuilder : public CelExpressionBuilder { enable_qualified_identifier_rewrites; } - void set_enable_regex(bool enable) { enable_regex_ = enable; } - void set_enable_regex_precompilation(bool enable) { enable_regex_precompilation_ = enable; } - void set_regex_max_program_size(int regex_max_program_size) { - regex_max_program_size_ = regex_max_program_size; - } - absl::StatusOr> CreateExpression( const google::api::expr::v1alpha1::Expr* expr, const google::api::expr::v1alpha1::SourceInfo* source_info) const override; @@ -179,16 +131,8 @@ class FlatExprBuilder : public CelExpressionBuilder { bool enable_unknowns_ = false; bool enable_unknown_function_results_ = false; bool enable_missing_attribute_errors_ = false; - bool shortcircuiting_ = true; - bool enable_comprehension_ = true; int comprehension_max_iterations_ = 0; - bool fail_on_warnings_ = true; - bool enable_qualified_type_identifiers_ = false; - bool enable_comprehension_list_append_ = false; - bool enable_wrapper_type_null_unboxing_ = false; bool enable_heterogeneous_equality_ = false; - bool enable_regex_ = false; - int regex_max_program_size_ = -1; bool enable_qualified_identifier_rewrites_ = false; bool enable_regex_precompilation_ = false; diff --git a/eval/compiler/flat_expr_builder_comprehensions_test.cc b/eval/compiler/flat_expr_builder_comprehensions_test.cc index 66baf200f..aa63a5109 100644 --- a/eval/compiler/flat_expr_builder_comprehensions_test.cc +++ b/eval/compiler/flat_expr_builder_comprehensions_test.cc @@ -54,7 +54,6 @@ TEST(FlatExprBuilderComprehensionsTest, NestedComp) { cel::RuntimeOptions options; options.enable_comprehension_list_append = true; FlatExprBuilder builder(options); - builder.set_enable_comprehension_list_append(true); ASSERT_OK_AND_ASSIGN(auto parsed_expr, parser::Parse("[1, 2].filter(x, [3, 4].all(y, x < y))")); @@ -74,7 +73,6 @@ TEST(FlatExprBuilderComprehensionsTest, MapComp) { cel::RuntimeOptions options; options.enable_comprehension_list_append = true; FlatExprBuilder builder(options); - builder.set_enable_comprehension_list_append(true); ASSERT_OK_AND_ASSIGN(auto parsed_expr, parser::Parse("[1, 2].map(x, x * 2)")); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); diff --git a/eval/compiler/flat_expr_builder_short_circuiting_conformance_test.cc b/eval/compiler/flat_expr_builder_short_circuiting_conformance_test.cc index acb502a38..dbb92029e 100644 --- a/eval/compiler/flat_expr_builder_short_circuiting_conformance_test.cc +++ b/eval/compiler/flat_expr_builder_short_circuiting_conformance_test.cc @@ -16,6 +16,7 @@ #include "eval/public/unknown_set.h" #include "internal/status_macros.h" #include "internal/testing.h" +#include "runtime/runtime_options.h" namespace google::api::expr::runtime { @@ -97,8 +98,13 @@ class ShortCircuitingTest : public testing::TestWithParam { ShortCircuitingTest() {} std::unique_ptr GetBuilder( bool enable_unknowns = false) { - auto result = std::make_unique(); - result->set_shortcircuiting(GetParam()); + cel::RuntimeOptions options; + options.short_circuiting = GetParam(); + if (enable_unknowns) { + options.unknown_processing = + cel::UnknownProcessingOptions::kAttributeAndFunction; + } + auto result = std::make_unique(options); if (enable_unknowns) { result->set_enable_unknown_function_results(true); result->set_enable_unknowns(true); diff --git a/eval/compiler/flat_expr_builder_test.cc b/eval/compiler/flat_expr_builder_test.cc index 527022665..1059846a5 100644 --- a/eval/compiler/flat_expr_builder_test.cc +++ b/eval/compiler/flat_expr_builder_test.cc @@ -58,6 +58,7 @@ #include "internal/status_macros.h" #include "internal/testing.h" #include "parser/parser.h" +#include "runtime/runtime_options.h" namespace google::api::expr::runtime { @@ -265,8 +266,6 @@ TEST(FlatExprBuilderTest, BinaryCallTooManyArguments) { TEST(FlatExprBuilderTest, TernaryCallTooManyArguments) { Expr expr; SourceInfo source_info; - FlatExprBuilder builder; - auto* call = expr.mutable_call_expr(); call->set_function(builtin::kTernary); call->mutable_target()->mutable_const_expr()->set_string_value("random"); @@ -274,15 +273,26 @@ TEST(FlatExprBuilderTest, TernaryCallTooManyArguments) { call->add_args()->mutable_const_expr()->set_int64_value(1); call->add_args()->mutable_const_expr()->set_int64_value(2); - EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), - StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("Invalid argument count"))); + { + cel::RuntimeOptions options; + options.short_circuiting = true; + FlatExprBuilder builder(options); + + EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid argument count"))); + } // Disable short-circuiting to ensure that a different visitor is used. - builder.set_shortcircuiting(false); - EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), - StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("Invalid argument count"))); + { + cel::RuntimeOptions options; + options.short_circuiting = false; + FlatExprBuilder builder(options); + + EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid argument count"))); + } } TEST(FlatExprBuilderTest, DelayedFunctionResolutionErrors) { @@ -297,8 +307,9 @@ TEST(FlatExprBuilderTest, DelayedFunctionResolutionErrors) { auto arg2 = call_expr->add_args(); arg2->mutable_ident_expr()->set_name("value"); - FlatExprBuilder builder; - builder.set_fail_on_warnings(false); + cel::RuntimeOptions options; + options.fail_on_warnings = false; + FlatExprBuilder builder(options); std::vector warnings; // Concat function not registered. @@ -335,38 +346,54 @@ TEST(FlatExprBuilderTest, Shortcircuiting) { auto arg2 = call_expr->add_args(); arg2->mutable_call_expr()->set_function("recorder2"); - FlatExprBuilder builder; - auto builtin = RegisterBuiltinFunctions(builder.GetRegistry()); + Activation activation; + google::protobuf::Arena arena; - int count1 = 0; - int count2 = 0; + // Shortcircuiting on + { + cel::RuntimeOptions options; + options.short_circuiting = true; + FlatExprBuilder builder(options); + auto builtin = RegisterBuiltinFunctions(builder.GetRegistry()); - ASSERT_OK(builder.GetRegistry()->Register( - std::make_unique("recorder1", &count1))); - ASSERT_OK(builder.GetRegistry()->Register( - std::make_unique("recorder2", &count2))); + int count1 = 0; + int count2 = 0; - // Shortcircuiting on. - ASSERT_OK_AND_ASSIGN(auto cel_expr_on, - builder.CreateExpression(&expr, &source_info)); - Activation activation; - google::protobuf::Arena arena; - auto eval_on = cel_expr_on->Evaluate(activation, &arena); - ASSERT_OK(eval_on); + ASSERT_OK(builder.GetRegistry()->Register( + std::make_unique("recorder1", &count1))); + ASSERT_OK(builder.GetRegistry()->Register( + std::make_unique("recorder2", &count2))); + + ASSERT_OK_AND_ASSIGN(auto cel_expr_on, + builder.CreateExpression(&expr, &source_info)); + ASSERT_OK(cel_expr_on->Evaluate(activation, &arena)); - EXPECT_THAT(count1, Eq(1)); - EXPECT_THAT(count2, Eq(0)); + EXPECT_THAT(count1, Eq(1)); + EXPECT_THAT(count2, Eq(0)); + } // Shortcircuiting off. - builder.set_shortcircuiting(false); - ASSERT_OK_AND_ASSIGN(auto cel_expr_off, - builder.CreateExpression(&expr, &source_info)); - count1 = 0; - count2 = 0; + { + cel::RuntimeOptions options; + options.short_circuiting = false; + FlatExprBuilder builder(options); + auto builtin = RegisterBuiltinFunctions(builder.GetRegistry()); + + int count1 = 0; + int count2 = 0; - ASSERT_OK(cel_expr_off->Evaluate(activation, &arena)); - EXPECT_THAT(count1, Eq(1)); - EXPECT_THAT(count2, Eq(1)); + ASSERT_OK(builder.GetRegistry()->Register( + std::make_unique("recorder1", &count1))); + ASSERT_OK(builder.GetRegistry()->Register( + std::make_unique("recorder2", &count2))); + + ASSERT_OK_AND_ASSIGN(auto cel_expr_off, + builder.CreateExpression(&expr, &source_info)); + + ASSERT_OK(cel_expr_off->Evaluate(activation, &arena)); + EXPECT_THAT(count1, Eq(1)); + EXPECT_THAT(count2, Eq(1)); + } } TEST(FlatExprBuilderTest, ShortcircuitingComprehension) { @@ -386,32 +413,46 @@ TEST(FlatExprBuilderTest, ShortcircuitingComprehension) { ->mutable_const_expr() ->set_bool_value(false); comprehension_expr->mutable_loop_step()->mutable_call_expr()->set_function( - "loop_step"); + "recorder_function1"); comprehension_expr->mutable_result()->mutable_const_expr()->set_bool_value( false); - FlatExprBuilder builder; - auto builtin = RegisterBuiltinFunctions(builder.GetRegistry()); - - int count = 0; - ASSERT_OK(builder.GetRegistry()->Register( - std::make_unique("loop_step", &count))); - - // Shortcircuiting on. - ASSERT_OK_AND_ASSIGN(auto cel_expr_on, - builder.CreateExpression(&expr, &source_info)); Activation activation; google::protobuf::Arena arena; - ASSERT_OK(cel_expr_on->Evaluate(activation, &arena)); - EXPECT_THAT(count, Eq(0)); - // Shortcircuiting off. - builder.set_shortcircuiting(false); - ASSERT_OK_AND_ASSIGN(auto cel_expr_off, - builder.CreateExpression(&expr, &source_info)); - count = 0; - ASSERT_OK(cel_expr_off->Evaluate(activation, &arena)); - EXPECT_THAT(count, Eq(3)); + // shortcircuiting on + { + cel::RuntimeOptions options; + options.short_circuiting = true; + FlatExprBuilder builder(options); + auto builtin = RegisterBuiltinFunctions(builder.GetRegistry()); + + int count = 0; + ASSERT_OK(builder.GetRegistry()->Register( + std::make_unique("recorder_function1", &count))); + + ASSERT_OK_AND_ASSIGN(auto cel_expr_on, + builder.CreateExpression(&expr, &source_info)); + + ASSERT_OK(cel_expr_on->Evaluate(activation, &arena)); + EXPECT_THAT(count, Eq(0)); + } + + // shortcircuiting off + { + cel::RuntimeOptions options; + options.short_circuiting = false; + FlatExprBuilder builder(options); + auto builtin = RegisterBuiltinFunctions(builder.GetRegistry()); + + int count = 0; + ASSERT_OK(builder.GetRegistry()->Register( + std::make_unique("recorder_function1", &count))); + ASSERT_OK_AND_ASSIGN(auto cel_expr_off, + builder.CreateExpression(&expr, &source_info)); + ASSERT_OK(cel_expr_off->Evaluate(activation, &arena)); + EXPECT_THAT(count, Eq(3)); + } } TEST(FlatExprBuilderTest, IdentExprUnsetName) { @@ -808,8 +849,9 @@ TEST(FlatExprBuilderTest, ParsedNamespacedFunctionResolutionOrderReceiverCall) { TEST(FlatExprBuilderTest, ParsedNamespacedFunctionSupportDisabled) { ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("ext.XOr(a, b)")); - FlatExprBuilder builder; - builder.set_fail_on_warnings(false); + cel::RuntimeOptions options; + options.fail_on_warnings = false; + FlatExprBuilder builder(options); std::vector build_warnings; builder.set_container("ext"); using FunctionAdapterT = FunctionAdapter; @@ -1741,8 +1783,9 @@ TEST(FlatExprBuilderTest, NullUnboxingEnabled) { TestMessage message; ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse("message.int32_wrapper_value")); - FlatExprBuilder builder; - builder.set_enable_wrapper_type_null_unboxing(true); + cel::RuntimeOptions options; + options.enable_empty_wrapper_null_unboxing = true; + FlatExprBuilder builder(options); ASSERT_OK_AND_ASSIGN(auto expression, builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info())); @@ -1761,8 +1804,9 @@ TEST(FlatExprBuilderTest, NullUnboxingDisabled) { TestMessage message; ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse("message.int32_wrapper_value")); - FlatExprBuilder builder; - builder.set_enable_wrapper_type_null_unboxing(false); + cel::RuntimeOptions options; + options.enable_empty_wrapper_null_unboxing = false; + FlatExprBuilder builder(options); ASSERT_OK_AND_ASSIGN(auto expression, builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info())); diff --git a/eval/eval/evaluator_core_test.cc b/eval/eval/evaluator_core_test.cc index 3cd93544e..d73df29d7 100644 --- a/eval/eval/evaluator_core_test.cc +++ b/eval/eval/evaluator_core_test.cc @@ -243,7 +243,6 @@ TEST(EvaluatorCoreTest, TraceTest) { options.short_circuiting = false; FlatExprBuilder builder(options); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); - builder.set_shortcircuiting(false); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr, &source_info)); diff --git a/eval/public/portable_cel_expr_builder_factory.cc b/eval/public/portable_cel_expr_builder_factory.cc index 9d1ce5055..ab6d97fe5 100644 --- a/eval/public/portable_cel_expr_builder_factory.cc +++ b/eval/public/portable_cel_expr_builder_factory.cc @@ -40,21 +40,10 @@ std::unique_ptr CreatePortableExprBuilder( builder->GetTypeRegistry()->RegisterTypeProvider(std::move(type_provider)); - builder->set_shortcircuiting(options.short_circuiting); - builder->set_enable_comprehension(options.enable_comprehension); - builder->set_enable_comprehension_list_append( - options.enable_comprehension_list_append); builder->set_comprehension_max_iterations( options.comprehension_max_iterations); - builder->set_fail_on_warnings(options.fail_on_warnings); - builder->set_enable_qualified_type_identifiers( - options.enable_qualified_type_identifiers); - builder->set_enable_wrapper_type_null_unboxing( - options.enable_empty_wrapper_null_unboxing); builder->set_enable_heterogeneous_equality( options.enable_heterogeneous_equality); - builder->set_enable_regex(options.enable_regex); - builder->set_regex_max_program_size(options.regex_max_program_size); switch (options.unknown_processing) { case UnknownProcessingOptions::kAttributeAndFunction: builder->set_enable_unknown_function_results(true); From 26bb9bad67056fc38291960792a8c0293fa2b84d Mon Sep 17 00:00:00 2001 From: jdtatum Date: Fri, 14 Apr 2023 22:17:43 +0000 Subject: [PATCH 216/303] Update evaluation frame to use cel::RuntimeOptions directly instead of copying individual options. PiperOrigin-RevId: 524398350 --- eval/compiler/flat_expr_builder.cc | 5 +- eval/compiler/flat_expr_builder.h | 29 -------- ...ilder_short_circuiting_conformance_test.cc | 4 - eval/compiler/flat_expr_builder_test.cc | 15 ++-- eval/eval/BUILD | 11 ++- eval/eval/comprehension_step_test.cc | 8 +- eval/eval/const_value_step_test.cc | 2 +- eval/eval/container_access_step_test.cc | 7 +- eval/eval/create_list_step_test.cc | 21 ++++-- eval/eval/create_struct_step_test.cc | 23 ++++-- eval/eval/evaluator_core.cc | 5 +- eval/eval/evaluator_core.h | 52 +++++-------- eval/eval/evaluator_core_test.cc | 20 ++--- eval/eval/function_step_test.cc | 74 ++++++++----------- eval/eval/ident_step_test.cc | 25 ++++--- eval/eval/logic_step_test.cc | 10 ++- eval/eval/select_step_test.cc | 36 +++++---- eval/eval/shadowable_value_step_test.cc | 2 +- eval/eval/ternary_step_test.cc | 10 ++- .../portable_cel_expr_builder_factory.cc | 19 ----- 20 files changed, 173 insertions(+), 205 deletions(-) diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index 8dca451da..2aca644cc 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -1322,10 +1322,7 @@ FlatExprBuilder::CreateExpressionImpl( std::unique_ptr expression_impl = std::make_unique( std::move(execution_path), GetTypeRegistry(), options_, - comprehension_max_iterations_, std::move(iter_variable_names), - enable_unknowns_, enable_unknown_function_results_, - enable_missing_attribute_errors_, enable_heterogeneous_equality_, - std::move(arena)); + std::move(iter_variable_names), std::move(arena)); if (warnings != nullptr) { *warnings = std::move(warnings_builder).warnings(); diff --git a/eval/compiler/flat_expr_builder.h b/eval/compiler/flat_expr_builder.h index 89c745e7f..36104ddaa 100644 --- a/eval/compiler/flat_expr_builder.h +++ b/eval/compiler/flat_expr_builder.h @@ -40,21 +40,6 @@ class FlatExprBuilder : public CelExpressionBuilder { // Create a flat expr builder with defaulted options. FlatExprBuilder() : CelExpressionBuilder() {} - // set_enable_unknowns controls support for unknowns in expressions created. - void set_enable_unknowns(bool enabled) { enable_unknowns_ = enabled; } - - // set_enable_missing_attribute_errors support for error injection in - // expressions created. - void set_enable_missing_attribute_errors(bool enabled) { - enable_missing_attribute_errors_ = enabled; - } - - // set_enable_unknown_function_results controls support for unknown function - // results. - void set_enable_unknown_function_results(bool enabled) { - enable_unknown_function_results_ = enabled; - } - // Toggle constant folding optimization. By default it is not enabled. // The provided arena is used to hold the generated constants. void set_constant_folding(bool enabled, google::protobuf::Arena* arena) { @@ -62,10 +47,6 @@ class FlatExprBuilder : public CelExpressionBuilder { constant_arena_ = arena; } - void set_comprehension_max_iterations(int max_iterations) { - comprehension_max_iterations_ = max_iterations; - } - // set_enable_comprehension_vulnerability_check inspects comprehension // sub-expressions for the presence of potential memory exhaustion. // @@ -78,13 +59,6 @@ class FlatExprBuilder : public CelExpressionBuilder { enable_comprehension_vulnerability_check_ = enabled; } - // If enable_heterogeneous_equality is enabled, the evaluator will use - // hetergeneous equality semantics. This includes the == operator and numeric - // index lookups in containers. - void set_enable_heterogeneous_equality(bool enabled) { - enable_heterogeneous_equality_ = enabled; - } - // If enable_qualified_identifier_rewrites is true, the evaluator will attempt // to disambiguate namespace qualified identifiers. // @@ -128,9 +102,6 @@ class FlatExprBuilder : public CelExpressionBuilder { cel::RuntimeOptions options_; - bool enable_unknowns_ = false; - bool enable_unknown_function_results_ = false; - bool enable_missing_attribute_errors_ = false; int comprehension_max_iterations_ = 0; bool enable_heterogeneous_equality_ = false; diff --git a/eval/compiler/flat_expr_builder_short_circuiting_conformance_test.cc b/eval/compiler/flat_expr_builder_short_circuiting_conformance_test.cc index dbb92029e..c346c6586 100644 --- a/eval/compiler/flat_expr_builder_short_circuiting_conformance_test.cc +++ b/eval/compiler/flat_expr_builder_short_circuiting_conformance_test.cc @@ -105,10 +105,6 @@ class ShortCircuitingTest : public testing::TestWithParam { cel::UnknownProcessingOptions::kAttributeAndFunction; } auto result = std::make_unique(options); - if (enable_unknowns) { - result->set_enable_unknown_function_results(true); - result->set_enable_unknowns(true); - } return result; } }; diff --git a/eval/compiler/flat_expr_builder_test.cc b/eval/compiler/flat_expr_builder_test.cc index 1059846a5..d9d937886 100644 --- a/eval/compiler/flat_expr_builder_test.cc +++ b/eval/compiler/flat_expr_builder_test.cc @@ -1359,8 +1359,9 @@ TEST(FlatExprBuilderTest, ComprehensionBudget) { })", &expr); - FlatExprBuilder builder; - builder.set_comprehension_max_iterations(1); + cel::RuntimeOptions options; + options.comprehension_max_iterations = 1; + FlatExprBuilder builder(options); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr, &source_info)); @@ -1824,8 +1825,9 @@ TEST(FlatExprBuilderTest, NullUnboxingDisabled) { TEST(FlatExprBuilderTest, HeterogeneousEqualityEnabled) { ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse("{1: 2, 2u: 3}[1.0]")); - FlatExprBuilder builder; - builder.set_enable_heterogeneous_equality(true); + cel::RuntimeOptions options; + options.enable_heterogeneous_equality = true; + FlatExprBuilder builder(options); ASSERT_OK_AND_ASSIGN(auto expression, builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info())); @@ -1841,8 +1843,9 @@ TEST(FlatExprBuilderTest, HeterogeneousEqualityEnabled) { TEST(FlatExprBuilderTest, HeterogeneousEqualityDisabled) { ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse("{1: 2, 2u: 3}[1.0]")); - FlatExprBuilder builder; - builder.set_enable_heterogeneous_equality(false); + cel::RuntimeOptions options; + options.enable_heterogeneous_equality = false; + FlatExprBuilder builder(options); ASSERT_OK_AND_ASSIGN(auto expression, builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info())); diff --git a/eval/eval/BUILD b/eval/eval/BUILD index 435bccb11..4157d657f 100644 --- a/eval/eval/BUILD +++ b/eval/eval/BUILD @@ -20,9 +20,8 @@ cc_library( ":evaluator_stack", "//base:ast_internal", "//base:memory_manager", - "//base:type_manager", - "//base:type_provider", - "//base:value_factory", + "//base:type", + "//base:value", "//eval/internal:adapter_activation_impl", "//eval/internal:interop", "//eval/public:base_activation", @@ -34,6 +33,7 @@ cc_library( "//extensions/protobuf:memory_manager", "//internal:casts", "//runtime:activation_interface", + "//runtime:runtime_options", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -489,6 +489,7 @@ cc_test( "//eval/public:activation", "//internal:status_macros", "//internal:testing", + "//runtime:runtime_options", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], @@ -543,6 +544,7 @@ cc_test( "//eval/public:unknown_set", "//internal:status_macros", "//internal:testing", + "//runtime:runtime_options", "@com_google_protobuf//:protobuf", ], ) @@ -570,6 +572,7 @@ cc_test( "//eval/testutil:test_message_cc_proto", "//internal:status_macros", "//internal:testing", + "//runtime:runtime_options", "//testutil:util", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -595,6 +598,7 @@ cc_test( "//eval/public:unknown_attribute_set", "//internal:status_macros", "//internal:testing", + "//runtime:runtime_options", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_protobuf//:protobuf", @@ -767,6 +771,7 @@ cc_test( "//eval/public:unknown_set", "//internal:status_macros", "//internal:testing", + "//runtime:runtime_options", "@com_google_protobuf//:protobuf", ], ) diff --git a/eval/eval/comprehension_step_test.cc b/eval/eval/comprehension_step_test.cc index b3bf109ef..8b789f0ac 100644 --- a/eval/eval/comprehension_step_test.cc +++ b/eval/eval/comprehension_step_test.cc @@ -43,9 +43,13 @@ class ListKeysStepTest : public testing::Test { std::unique_ptr MakeExpression( ExecutionPath&& path, bool unknown_attributes = false) { + cel::RuntimeOptions options; + if (unknown_attributes) { + options.unknown_processing = + cel::UnknownProcessingOptions::kAttributeAndFunction; + } return std::make_unique( - std::move(path), &TestTypeRegistry(), cel::RuntimeOptions{}, 0, - std::set(), unknown_attributes, unknown_attributes); + std::move(path), &TestTypeRegistry(), options, std::set()); } private: diff --git a/eval/eval/const_value_step_test.cc b/eval/eval/const_value_step_test.cc index 624ef0998..bd2a247ec 100644 --- a/eval/eval/const_value_step_test.cc +++ b/eval/eval/const_value_step_test.cc @@ -38,7 +38,7 @@ absl::StatusOr RunConstantExpression(const Expr* expr, CelExpressionFlatImpl impl(std::move(path), &google::api::expr::runtime::TestTypeRegistry(), - cel::RuntimeOptions{}, 0, {}); + cel::RuntimeOptions{}, {}); google::api::expr::runtime::Activation activation; diff --git a/eval/eval/container_access_step_test.cc b/eval/eval/container_access_step_test.cc index d1e147349..0b5938ccc 100644 --- a/eval/eval/container_access_step_test.cc +++ b/eval/eval/container_access_step_test.cc @@ -67,8 +67,11 @@ CelValue EvaluateAttributeHelper( path.push_back(std::move(CreateIdentStep(key_expr.ident_expr(), 2).value())); path.push_back(std::move(CreateContainerAccessStep(call, 3).value())); - CelExpressionFlatImpl cel_expr(std::move(path), &TestTypeRegistry(), - cel::RuntimeOptions{}, 0, {}, enable_unknown); + cel::RuntimeOptions options; + options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; + options.enable_heterogeneous_equality = false; + CelExpressionFlatImpl cel_expr(std::move(path), &TestTypeRegistry(), options, + {}); Activation activation; activation.InsertValue("container", container); diff --git a/eval/eval/create_list_step_test.cc b/eval/eval/create_list_step_test.cc index c2ad43553..39ca7dc1d 100644 --- a/eval/eval/create_list_step_test.cc +++ b/eval/eval/create_list_step_test.cc @@ -15,6 +15,7 @@ #include "eval/public/unknown_attribute_set.h" #include "internal/status_macros.h" #include "internal/testing.h" +#include "runtime/runtime_options.h" namespace google::api::expr::runtime { @@ -45,9 +46,12 @@ absl::StatusOr RunExpression(const std::vector& values, CEL_ASSIGN_OR_RETURN(auto step, CreateCreateListStep(create_list, dummy_expr.id())); path.push_back(std::move(step)); - - CelExpressionFlatImpl cel_expr(std::move(path), &TestTypeRegistry(), - cel::RuntimeOptions{}, 0, {}, enable_unknowns); + cel::RuntimeOptions options; + if (enable_unknowns) { + options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; + } + CelExpressionFlatImpl cel_expr(std::move(path), &TestTypeRegistry(), options, + {}); Activation activation; return cel_expr.Evaluate(activation, arena); @@ -79,8 +83,13 @@ absl::StatusOr RunExpressionWithCelValues( CreateCreateListStep(create_list, dummy_expr.id())); path.push_back(std::move(step0)); - CelExpressionFlatImpl cel_expr(std::move(path), &TestTypeRegistry(), - cel::RuntimeOptions{}, 0, {}, enable_unknowns); + cel::RuntimeOptions options; + if (enable_unknowns) { + options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; + } + + CelExpressionFlatImpl cel_expr(std::move(path), &TestTypeRegistry(), options, + {}); return cel_expr.Evaluate(activation, arena); } @@ -102,7 +111,7 @@ TEST(CreateListStepTest, TestCreateListStackUnderflow) { path.push_back(std::move(step0)); CelExpressionFlatImpl cel_expr(std::move(path), &TestTypeRegistry(), - cel::RuntimeOptions{}, 0, {}); + cel::RuntimeOptions{}, {}); Activation activation; google::protobuf::Arena arena; diff --git a/eval/eval/create_struct_step_test.cc b/eval/eval/create_struct_step_test.cc index 7ca0fd9e3..9a729a574 100644 --- a/eval/eval/create_struct_step_test.cc +++ b/eval/eval/create_struct_step_test.cc @@ -76,8 +76,11 @@ absl::StatusOr RunExpression(absl::string_view field, path.push_back(std::move(step0)); path.push_back(std::move(step1)); - CelExpressionFlatImpl cel_expr(std::move(path), &type_registry, - cel::RuntimeOptions{}, 0, {}, enable_unknowns); + cel::RuntimeOptions options; + if (enable_unknowns) { + options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; + } + CelExpressionFlatImpl cel_expr(std::move(path), &type_registry, options, {}); Activation activation; activation.InsertValue("message", value); @@ -163,8 +166,13 @@ absl::StatusOr RunCreateMapExpression( CreateCreateStructStep(create_struct, expr1.id())); path.push_back(std::move(step1)); - CelExpressionFlatImpl cel_expr(std::move(path), &TestTypeRegistry(), - cel::RuntimeOptions{}, 0, {}, enable_unknowns); + cel::RuntimeOptions options; + if (enable_unknowns) { + options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; + } + + CelExpressionFlatImpl cel_expr(std::move(path), &TestTypeRegistry(), options, + {}); return cel_expr.Evaluate(activation, arena); } @@ -189,8 +197,11 @@ TEST_P(CreateCreateStructStepTest, TestEmptyMessageCreation) { expr1.id())); path.push_back(std::move(step)); - CelExpressionFlatImpl cel_expr(std::move(path), &type_registry, - cel::RuntimeOptions{}, 0, {}, GetParam()); + cel::RuntimeOptions options; + if (GetParam()) { + options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; + } + CelExpressionFlatImpl cel_expr(std::move(path), &type_registry, options, {}); Activation activation; google::protobuf::Arena arena; diff --git a/eval/eval/evaluator_core.cc b/eval/eval/evaluator_core.cc index 6b9c3cd67..a2182f335 100644 --- a/eval/eval/evaluator_core.cc +++ b/eval/eval/evaluator_core.cc @@ -157,10 +157,7 @@ absl::StatusOr CelExpressionFlatImpl::Trace( ::cel::internal::down_cast(_state); state->Reset(); - ExecutionFrame frame( - path_, activation, &type_registry_, options_, max_iterations_, state, - enable_unknowns_, enable_unknown_function_results_, - enable_missing_attribute_errors_, enable_heterogeneous_equality_); + ExecutionFrame frame(path_, activation, &type_registry_, options_, state); EvaluatorStack* stack = &frame.value_stack(); size_t initial_stack_size = stack->size(); diff --git a/eval/eval/evaluator_core.h b/eval/eval/evaluator_core.h index 266411045..9678561b9 100644 --- a/eval/eval/evaluator_core.h +++ b/eval/eval/evaluator_core.h @@ -38,6 +38,7 @@ #include "eval/public/unknown_attribute_set.h" #include "extensions/protobuf/memory_manager.h" #include "runtime/activation_interface.h" +#include "runtime/runtime_options.h" namespace google::api::expr::runtime { @@ -135,26 +136,18 @@ class ExecutionFrame { ExecutionFrame(const ExecutionPath& flat, const BaseActivation& activation, const CelTypeRegistry* type_registry, - const cel::RuntimeOptions& options, int max_iterations, - CelExpressionFlatEvaluationState* state, bool enable_unknowns, - bool enable_unknown_function_results, - bool enable_missing_attribute_errors, - bool enable_heterogeneous_numeric_lookups) + const cel::RuntimeOptions& options, + CelExpressionFlatEvaluationState* state) : pc_(0UL), execution_path_(flat), activation_(activation), modern_activation_(activation), type_registry_(*type_registry), options_(options), - enable_unknowns_(enable_unknowns), - enable_unknown_function_results_(enable_unknown_function_results), - enable_missing_attribute_errors_(enable_missing_attribute_errors), - enable_heterogeneous_numeric_lookups_( - enable_heterogeneous_numeric_lookups), attribute_utility_(modern_activation_.GetUnknownAttributes(), modern_activation_.GetMissingAttributes(), state->memory_manager()), - max_iterations_(max_iterations), + max_iterations_(options_.comprehension_max_iterations), iterations_(0), state_(state) {} @@ -175,16 +168,23 @@ class ExecutionFrame { } EvaluatorStack& value_stack() { return state_->value_stack(); } - bool enable_unknowns() const { return enable_unknowns_; } + + bool enable_unknowns() const { + return options_.unknown_processing != + cel::UnknownProcessingOptions::kDisabled; + } + bool enable_unknown_function_results() const { - return enable_unknown_function_results_; + return options_.unknown_processing == + cel::UnknownProcessingOptions::kAttributeAndFunction; } + bool enable_missing_attribute_errors() const { - return enable_missing_attribute_errors_; + return options_.enable_missing_attribute_errors; } bool enable_heterogeneous_numeric_lookups() const { - return enable_heterogeneous_numeric_lookups_; + return options_.enable_heterogeneous_equality; } cel::MemoryManager& memory_manager() { return state_->memory_manager(); } @@ -259,10 +259,6 @@ class ExecutionFrame { cel::interop_internal::AdapterActivationImpl modern_activation_; const CelTypeRegistry& type_registry_; const cel::RuntimeOptions& options_; // owned by the FlatExpr instance - bool enable_unknowns_; - bool enable_unknown_function_results_; - bool enable_missing_attribute_errors_; - bool enable_heterogeneous_numeric_lookups_; AttributeUtility attribute_utility_; const int max_iterations_; int iterations_; @@ -280,23 +276,14 @@ class CelExpressionFlatImpl : public CelExpression { // bound). CelExpressionFlatImpl(ExecutionPath path, const CelTypeRegistry* type_registry, - const cel::RuntimeOptions& options, int max_iterations, + const cel::RuntimeOptions& options, std::set iter_variable_names, - bool enable_unknowns = false, - bool enable_unknown_function_results = false, - bool enable_missing_attribute_errors = false, - bool enable_heterogeneous_equality = false, std::unique_ptr arena = nullptr) : arena_(std::move(arena)), path_(std::move(path)), type_registry_(*type_registry), options_(options), - max_iterations_(max_iterations), - iter_variable_names_(std::move(iter_variable_names)), - enable_unknowns_(enable_unknowns), - enable_unknown_function_results_(enable_unknown_function_results), - enable_missing_attribute_errors_(enable_missing_attribute_errors), - enable_heterogeneous_equality_(enable_heterogeneous_equality) {} + iter_variable_names_(std::move(iter_variable_names)) {} // Move-only CelExpressionFlatImpl(const CelExpressionFlatImpl&) = delete; @@ -331,12 +318,7 @@ class CelExpressionFlatImpl : public CelExpression { const ExecutionPath path_; const CelTypeRegistry& type_registry_; cel::RuntimeOptions options_; - const int max_iterations_; const std::set iter_variable_names_; - bool enable_unknowns_; - bool enable_unknown_function_results_; - bool enable_missing_attribute_errors_; - bool enable_heterogeneous_equality_; }; } // namespace google::api::expr::runtime diff --git a/eval/eval/evaluator_core_test.cc b/eval/eval/evaluator_core_test.cc index d73df29d7..0aca625cd 100644 --- a/eval/eval/evaluator_core_test.cc +++ b/eval/eval/evaluator_core_test.cc @@ -72,14 +72,11 @@ TEST(EvaluatorCoreTest, ExecutionFrameNext) { auto dummy_expr = std::make_unique(); + cel::RuntimeOptions options; + options.unknown_processing = cel::UnknownProcessingOptions::kDisabled; Activation activation; CelExpressionFlatEvaluationState state(path.size(), {}, nullptr); - ExecutionFrame frame(path, activation, &TestTypeRegistry(), - cel::RuntimeOptions{}, 0, &state, - /*enable_unknowns=*/false, - /*enable_unknown_funcion_results=*/false, - /*enable_missing_attribute_errors=*/false, - /*enable_heterogeneous_numeric_lookups=*/true); + ExecutionFrame frame(path, activation, &TestTypeRegistry(), options, &state); EXPECT_THAT(frame.Next(), Eq(path[0].get())); EXPECT_THAT(frame.Next(), Eq(path[1].get())); @@ -98,12 +95,9 @@ TEST(EvaluatorCoreTest, ExecutionFrameSetGetClearVar) { ProtoMemoryManager manager(&arena); ExecutionPath path; CelExpressionFlatEvaluationState state(path.size(), {test_iter_var}, nullptr); - ExecutionFrame frame(path, activation, &TestTypeRegistry(), - cel::RuntimeOptions{}, 0, &state, - /*enable_unknowns=*/false, - /*enable_unknown_funcion_results=*/false, - /*enable_missing_attribute_errors=*/false, - /*enable_heterogeneous_numeric_lookups=*/true); + cel::RuntimeOptions options; + options.unknown_processing = cel::UnknownProcessingOptions::kDisabled; + ExecutionFrame frame(path, activation, &TestTypeRegistry(), options, &state); auto original = cel::interop_internal::CreateIntValue(test_value); Expr ident; @@ -163,7 +157,7 @@ TEST(EvaluatorCoreTest, SimpleEvaluatorTest) { path.push_back(std::move(incr_step2)); CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), - cel::RuntimeOptions{}, 0, {}); + cel::RuntimeOptions{}, {}); Activation activation; google::protobuf::Arena arena; diff --git a/eval/eval/function_step_test.cc b/eval/eval/function_step_test.cc index eefd34f58..7e7da15ad 100644 --- a/eval/eval/function_step_test.cc +++ b/eval/eval/function_step_test.cc @@ -218,25 +218,11 @@ class FunctionStepTest public: // underlying expression impl moves path std::unique_ptr GetExpression(ExecutionPath&& path) { - bool unknowns = false; - bool unknown_function_results = false; - switch (GetParam()) { - case UnknownProcessingOptions::kAttributeAndFunction: - unknowns = true; - unknown_function_results = true; - break; - case UnknownProcessingOptions::kAttributeOnly: - unknowns = true; - unknown_function_results = false; - break; - case UnknownProcessingOptions::kDisabled: - unknowns = false; - unknown_function_results = false; - break; - } + cel::RuntimeOptions options; + options.unknown_processing = GetParam(); + return std::make_unique( - std::move(path), &TestTypeRegistry(), cel::RuntimeOptions{}, 0, - std::set(), unknowns, unknown_function_results); + std::move(path), &TestTypeRegistry(), options, std::set()); } }; @@ -579,18 +565,11 @@ class FunctionStepTestUnknowns : public testing::TestWithParam { public: std::unique_ptr GetExpression(ExecutionPath&& path) { - bool unknown_functions; - switch (GetParam()) { - case UnknownProcessingOptions::kAttributeAndFunction: - unknown_functions = true; - break; - default: - unknown_functions = false; - break; - } + cel::RuntimeOptions options; + options.unknown_processing = GetParam(); + return std::make_unique( - std::move(path), &TestTypeRegistry(), cel::RuntimeOptions{}, 0, - std::set(), true, unknown_functions); + std::move(path), &TestTypeRegistry(), options, std::set()); } }; @@ -723,9 +702,10 @@ TEST(FunctionStepTestUnknownFunctionResults, CaptureArgs) { path.push_back(std::move(step0)); path.push_back(std::move(step1)); path.push_back(std::move(step2)); - - CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), - cel::RuntimeOptions{}, 0, {}, true, true); + cel::RuntimeOptions options; + options.unknown_processing = + cel::UnknownProcessingOptions::kAttributeAndFunction; + CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), options, {}); Activation activation; google::protobuf::Arena arena; @@ -767,8 +747,10 @@ TEST(FunctionStepTestUnknownFunctionResults, MergeDownCaptureArgs) { path.push_back(std::move(step5)); path.push_back(std::move(step6)); - CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), - cel::RuntimeOptions{}, 0, {}, true, true); + cel::RuntimeOptions options; + options.unknown_processing = + cel::UnknownProcessingOptions::kAttributeAndFunction; + CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), options, {}); Activation activation; google::protobuf::Arena arena; @@ -810,8 +792,10 @@ TEST(FunctionStepTestUnknownFunctionResults, MergeCaptureArgs) { path.push_back(std::move(step5)); path.push_back(std::move(step6)); - CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), - cel::RuntimeOptions{}, 0, {}, true, true); + cel::RuntimeOptions options; + options.unknown_processing = + cel::UnknownProcessingOptions::kAttributeAndFunction; + CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), options, {}); Activation activation; google::protobuf::Arena arena; @@ -848,8 +832,10 @@ TEST(FunctionStepTestUnknownFunctionResults, UnknownVsErrorPrecedenceTest) { path.push_back(std::move(step1)); path.push_back(std::move(step2)); - CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), - cel::RuntimeOptions{}, 0, {}, true, true); + cel::RuntimeOptions options; + options.unknown_processing = + cel::UnknownProcessingOptions::kAttributeAndFunction; + CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), options, {}); Activation activation; google::protobuf::Arena arena; @@ -931,8 +917,10 @@ TEST(FunctionStepStrictnessTest, MakeTestFunctionStep(call1, registry)); path.push_back(std::move(step0)); path.push_back(std::move(step1)); - CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), - cel::RuntimeOptions{}, 0, {}, true, true); + cel::RuntimeOptions options; + options.unknown_processing = + cel::UnknownProcessingOptions::kAttributeAndFunction; + CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), options, {}); Activation activation; google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN(CelValue value, impl.Evaluate(activation, &arena)); @@ -956,8 +944,10 @@ TEST(FunctionStepStrictnessTest, IfFunctionNonStrictAndGivenUnknownInvokesIt) { path.push_back(std::move(step0)); path.push_back(std::move(step1)); Expr placeholder_expr; - CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), - cel::RuntimeOptions{}, 0, {}, true, true); + cel::RuntimeOptions options; + options.unknown_processing = + cel::UnknownProcessingOptions::kAttributeAndFunction; + CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), options, {}); Activation activation; google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN(CelValue value, impl.Evaluate(activation, &arena)); diff --git a/eval/eval/ident_step_test.cc b/eval/eval/ident_step_test.cc index bc5e3b36b..d7207fb10 100644 --- a/eval/eval/ident_step_test.cc +++ b/eval/eval/ident_step_test.cc @@ -11,6 +11,7 @@ #include "eval/public/activation.h" #include "internal/status_macros.h" #include "internal/testing.h" +#include "runtime/runtime_options.h" namespace google::api::expr::runtime { @@ -31,7 +32,7 @@ TEST(IdentStepTest, TestIdentStep) { path.push_back(std::move(step)); CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), - cel::RuntimeOptions{}, 0, {}); + cel::RuntimeOptions{}, {}); Activation activation; Arena arena; @@ -58,7 +59,7 @@ TEST(IdentStepTest, TestIdentStepNameNotFound) { path.push_back(std::move(step)); CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), - cel::RuntimeOptions{}, 0, {}); + cel::RuntimeOptions{}, {}); Activation activation; Arena arena; @@ -80,10 +81,9 @@ TEST(IdentStepTest, DisableMissingAttributeErrorsOK) { ExecutionPath path; path.push_back(std::move(step)); - - CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), - cel::RuntimeOptions{}, 0, {}, - /*enable_unknowns=*/false); + cel::RuntimeOptions options; + options.unknown_processing = cel::UnknownProcessingOptions::kDisabled; + CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), options, {}); Activation activation; Arena arena; @@ -117,9 +117,11 @@ TEST(IdentStepTest, TestIdentStepMissingAttributeErrors) { ExecutionPath path; path.push_back(std::move(step)); - CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), - cel::RuntimeOptions{}, 0, {}, false, false, - /*enable_missing_attribute_errors=*/true); + cel::RuntimeOptions options; + options.unknown_processing = cel::UnknownProcessingOptions::kDisabled; + options.enable_missing_attribute_errors = true; + + CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), options, {}); Activation activation; Arena arena; @@ -155,8 +157,9 @@ TEST(IdentStepTest, TestIdentStepUnknownAttribute) { path.push_back(std::move(step)); // Expression with unknowns enabled. - CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), - cel::RuntimeOptions{}, 0, {}, true); + cel::RuntimeOptions options; + options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; + CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), options, {}); Activation activation; Arena arena; diff --git a/eval/eval/logic_step_test.cc b/eval/eval/logic_step_test.cc index 57ee7d172..471351206 100644 --- a/eval/eval/logic_step_test.cc +++ b/eval/eval/logic_step_test.cc @@ -11,6 +11,7 @@ #include "eval/public/unknown_set.h" #include "internal/status_macros.h" #include "internal/testing.h" +#include "runtime/runtime_options.h" namespace google::api::expr::runtime { @@ -42,8 +43,13 @@ class LogicStepTest : public testing::TestWithParam { path.push_back(std::move(step)); auto dummy_expr = std::make_unique(); - CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), - cel::RuntimeOptions{}, 0, {}, enable_unknown); + cel::RuntimeOptions options; + if (enable_unknown) { + options.unknown_processing = + cel::UnknownProcessingOptions::kAttributeOnly; + } + CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), options, + {}); Activation activation; activation.InsertValue("name0", arg0); diff --git a/eval/eval/select_step_test.cc b/eval/eval/select_step_test.cc index a6b4b8250..0c28168d9 100644 --- a/eval/eval/select_step_test.cc +++ b/eval/eval/select_step_test.cc @@ -24,6 +24,7 @@ #include "eval/testutil/test_message.pb.h" #include "internal/status_macros.h" #include "internal/testing.h" +#include "runtime/runtime_options.h" #include "testutil/util.h" namespace google::api::expr::runtime { @@ -92,9 +93,13 @@ absl::StatusOr RunExpression(const CelValue target, path.push_back(std::move(step0)); path.push_back(std::move(step1)); + cel::RuntimeOptions runtime_options; + if (options.enable_unknowns) { + runtime_options.unknown_processing = + cel::UnknownProcessingOptions::kAttributeOnly; + } CelExpressionFlatImpl cel_expr(std::move(path), &TestTypeRegistry(), - cel::RuntimeOptions{}, 0, {}, - options.enable_unknowns); + runtime_options, {}); Activation activation; activation.InsertValue("target", target); @@ -280,7 +285,7 @@ TEST(SelectStepTest, MapPresenseIsErrorTest) { path.push_back(std::move(step1)); path.push_back(std::move(step2)); CelExpressionFlatImpl cel_expr(std::move(path), &TestTypeRegistry(), - cel::RuntimeOptions{}, 0, {}, false); + cel::RuntimeOptions{}, {}); Activation activation; activation.InsertValue("target", CelProtoWrapper::CreateMessage(&message, &arena)); @@ -810,9 +815,12 @@ TEST_P(SelectStepTest, CelErrorAsArgument) { CelError error; google::protobuf::Arena arena; - bool enable_unknowns = GetParam(); - CelExpressionFlatImpl cel_expr(std::move(path), &TestTypeRegistry(), - cel::RuntimeOptions{}, 0, {}, enable_unknowns); + cel::RuntimeOptions options; + if (GetParam()) { + options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; + } + CelExpressionFlatImpl cel_expr(std::move(path), &TestTypeRegistry(), options, + {}); Activation activation; activation.InsertValue("message", CelValue::CreateError(&error)); @@ -846,8 +854,7 @@ TEST(SelectStepTest, DisableMissingAttributeOK) { path.push_back(std::move(step1)); CelExpressionFlatImpl cel_expr(std::move(path), &TestTypeRegistry(), - cel::RuntimeOptions{}, 0, {}, - /*enable_unknowns=*/false); + cel::RuntimeOptions{}, {}); Activation activation; activation.InsertValue("message", CelProtoWrapper::CreateMessage(&message, &arena)); @@ -887,9 +894,10 @@ TEST(SelectStepTest, UnrecoverableUnknownValueProducesError) { path.push_back(std::move(step0)); path.push_back(std::move(step1)); - CelExpressionFlatImpl cel_expr(std::move(path), &TestTypeRegistry(), - cel::RuntimeOptions{}, 0, {}, false, false, - /*enable_missing_attribute_errors=*/true); + cel::RuntimeOptions options; + options.enable_missing_attribute_errors = true; + CelExpressionFlatImpl cel_expr(std::move(path), &TestTypeRegistry(), options, + {}); Activation activation; activation.InsertValue("message", CelProtoWrapper::CreateMessage(&message, &arena)); @@ -935,8 +943,10 @@ TEST(SelectStepTest, UnknownPatternResolvesToUnknown) { path.push_back(*std::move(step0_status)); path.push_back(*std::move(step1_status)); - CelExpressionFlatImpl cel_expr(std::move(path), &TestTypeRegistry(), - cel::RuntimeOptions{}, 0, {}, true); + cel::RuntimeOptions options; + options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; + CelExpressionFlatImpl cel_expr(std::move(path), &TestTypeRegistry(), options, + {}); { std::vector unknown_patterns; diff --git a/eval/eval/shadowable_value_step_test.cc b/eval/eval/shadowable_value_step_test.cc index 337f2af62..9bb32d1c2 100644 --- a/eval/eval/shadowable_value_step_test.cc +++ b/eval/eval/shadowable_value_step_test.cc @@ -34,7 +34,7 @@ absl::StatusOr RunShadowableExpression(std::string identifier, path.push_back(std::move(step)); CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), - cel::RuntimeOptions{}, 0, {}); + cel::RuntimeOptions{}, {}); return impl.Evaluate(activation, arena); } diff --git a/eval/eval/ternary_step_test.cc b/eval/eval/ternary_step_test.cc index c2177901e..524d7c84a 100644 --- a/eval/eval/ternary_step_test.cc +++ b/eval/eval/ternary_step_test.cc @@ -11,6 +11,7 @@ #include "eval/public/unknown_set.h" #include "internal/status_macros.h" #include "internal/testing.h" +#include "runtime/runtime_options.h" namespace google::api::expr::runtime { @@ -53,8 +54,13 @@ class LogicStepTest : public testing::TestWithParam { CEL_ASSIGN_OR_RETURN(step, CreateTernaryStep(4)); path.push_back(std::move(step)); - CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), - cel::RuntimeOptions{}, 0, {}, enable_unknown); + cel::RuntimeOptions options; + if (enable_unknown) { + options.unknown_processing = + cel::UnknownProcessingOptions::kAttributeOnly; + } + CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), options, + {}); Activation activation; std::string value("test"); diff --git a/eval/public/portable_cel_expr_builder_factory.cc b/eval/public/portable_cel_expr_builder_factory.cc index ab6d97fe5..9dabe9cd8 100644 --- a/eval/public/portable_cel_expr_builder_factory.cc +++ b/eval/public/portable_cel_expr_builder_factory.cc @@ -40,25 +40,6 @@ std::unique_ptr CreatePortableExprBuilder( builder->GetTypeRegistry()->RegisterTypeProvider(std::move(type_provider)); - builder->set_comprehension_max_iterations( - options.comprehension_max_iterations); - builder->set_enable_heterogeneous_equality( - options.enable_heterogeneous_equality); - switch (options.unknown_processing) { - case UnknownProcessingOptions::kAttributeAndFunction: - builder->set_enable_unknown_function_results(true); - builder->set_enable_unknowns(true); - break; - case UnknownProcessingOptions::kAttributeOnly: - builder->set_enable_unknowns(true); - break; - case UnknownProcessingOptions::kDisabled: - break; - } - - builder->set_enable_missing_attribute_errors( - options.enable_missing_attribute_errors); - // TODO(issues/5): These need to be abstracted to avoid bringing in too // many build dependencies by default. builder->set_enable_comprehension_vulnerability_check( From 2bff4165d62d715d9786aa49febf2069374ee792 Mon Sep 17 00:00:00 2001 From: jdtatum Date: Fri, 14 Apr 2023 22:35:12 +0000 Subject: [PATCH 217/303] Remove iter_var name tracking from flat expr builder. This was unused after updating how iter variables are maintained in evaluation. PiperOrigin-RevId: 524402075 --- eval/compiler/flat_expr_builder.cc | 27 +++++-------------------- eval/eval/comprehension_step_test.cc | 2 +- eval/eval/container_access_step_test.cc | 3 +-- eval/eval/create_list_step_test.cc | 8 +++----- eval/eval/create_struct_step_test.cc | 7 +++---- eval/eval/evaluator_core.cc | 8 +++----- eval/eval/evaluator_core.h | 19 +++++++---------- eval/eval/evaluator_core_test.cc | 6 +++--- eval/eval/function_step_test.cc | 16 +++++++-------- eval/eval/ident_step_test.cc | 10 ++++----- eval/eval/select_step_test.cc | 15 ++++++-------- 11 files changed, 45 insertions(+), 76 deletions(-) diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index 2aca644cc..3c88a6fdb 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -276,7 +276,6 @@ class FlatExprVisitor : public cel::ast::internal::AstVisitor { google::protobuf::Arena* constant_arena, bool enable_comprehension_vulnerability_check, google::api::expr::runtime::BuilderWarnings* warnings, - std::set* iter_variable_names, bool enable_regex_precompilation, const absl::flat_hash_map* reference_map, @@ -291,13 +290,10 @@ class FlatExprVisitor : public cel::ast::internal::AstVisitor { enable_comprehension_vulnerability_check_( enable_comprehension_vulnerability_check), builder_warnings_(warnings), - iter_variable_names_(iter_variable_names), enable_regex_precompilation_(enable_regex_precompilation), regex_program_builder_(options_.regex_max_program_size), reference_map_(reference_map), - arena_(arena) { - DCHECK(iter_variable_names_); - } + arena_(arena) {} void PreVisitExpr(const cel::ast::internal::Expr* expr, const cel::ast::internal::SourcePosition*) override { @@ -632,15 +628,6 @@ class FlatExprVisitor : public cel::ast::internal::AstVisitor { auto cond_visitor = FindCondVisitor(expr); cond_visitor->PostVisit(expr); cond_visitor_stack_.pop(); - - // Save off the names of the variables we're using, such that we have a - // full set of the names from the entire evaluation tree at the end. - if (!comprehension_expr->accu_var().empty()) { - iter_variable_names_->insert(comprehension_expr->accu_var()); - } - if (!comprehension_expr->iter_var().empty()) { - iter_variable_names_->insert(comprehension_expr->iter_var()); - } } // Invoked after each argument node processed. @@ -829,8 +816,6 @@ class FlatExprVisitor : public cel::ast::internal::AstVisitor { bool enable_comprehension_vulnerability_check_; google::api::expr::runtime::BuilderWarnings* builder_warnings_; - std::set* iter_variable_names_; - bool enable_regex_precompilation_; RegexProgramBuilder regex_program_builder_; const absl::flat_hash_map* const @@ -1301,12 +1286,10 @@ FlatExprBuilder::CreateExpressionImpl( auto arena = std::make_unique(); - std::set iter_variable_names; FlatExprVisitor visitor( resolver, &execution_path, options_, constant_idents, constant_arena_, enable_comprehension_vulnerability_check_, &warnings_builder, - &iter_variable_names, enable_regex_precompilation_, - &ast_impl.reference_map(), arena.get()); + enable_regex_precompilation_, &ast_impl.reference_map(), arena.get()); AstTraverse(effective_expr, &ast_impl.source_info(), &visitor); @@ -1320,9 +1303,9 @@ FlatExprBuilder::CreateExpressionImpl( } std::unique_ptr expression_impl = - std::make_unique( - std::move(execution_path), GetTypeRegistry(), options_, - std::move(iter_variable_names), std::move(arena)); + std::make_unique(std::move(execution_path), + GetTypeRegistry(), options_, + std::move(arena)); if (warnings != nullptr) { *warnings = std::move(warnings_builder).warnings(); diff --git a/eval/eval/comprehension_step_test.cc b/eval/eval/comprehension_step_test.cc index 8b789f0ac..f5ca205b3 100644 --- a/eval/eval/comprehension_step_test.cc +++ b/eval/eval/comprehension_step_test.cc @@ -49,7 +49,7 @@ class ListKeysStepTest : public testing::Test { cel::UnknownProcessingOptions::kAttributeAndFunction; } return std::make_unique( - std::move(path), &TestTypeRegistry(), options, std::set()); + std::move(path), &TestTypeRegistry(), options); } private: diff --git a/eval/eval/container_access_step_test.cc b/eval/eval/container_access_step_test.cc index 0b5938ccc..9ec7fd6f9 100644 --- a/eval/eval/container_access_step_test.cc +++ b/eval/eval/container_access_step_test.cc @@ -70,8 +70,7 @@ CelValue EvaluateAttributeHelper( cel::RuntimeOptions options; options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; options.enable_heterogeneous_equality = false; - CelExpressionFlatImpl cel_expr(std::move(path), &TestTypeRegistry(), options, - {}); + CelExpressionFlatImpl cel_expr(std::move(path), &TestTypeRegistry(), options); Activation activation; activation.InsertValue("container", container); diff --git a/eval/eval/create_list_step_test.cc b/eval/eval/create_list_step_test.cc index 39ca7dc1d..519c4726b 100644 --- a/eval/eval/create_list_step_test.cc +++ b/eval/eval/create_list_step_test.cc @@ -50,8 +50,7 @@ absl::StatusOr RunExpression(const std::vector& values, if (enable_unknowns) { options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; } - CelExpressionFlatImpl cel_expr(std::move(path), &TestTypeRegistry(), options, - {}); + CelExpressionFlatImpl cel_expr(std::move(path), &TestTypeRegistry(), options); Activation activation; return cel_expr.Evaluate(activation, arena); @@ -88,8 +87,7 @@ absl::StatusOr RunExpressionWithCelValues( options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; } - CelExpressionFlatImpl cel_expr(std::move(path), &TestTypeRegistry(), options, - {}); + CelExpressionFlatImpl cel_expr(std::move(path), &TestTypeRegistry(), options); return cel_expr.Evaluate(activation, arena); } @@ -111,7 +109,7 @@ TEST(CreateListStepTest, TestCreateListStackUnderflow) { path.push_back(std::move(step0)); CelExpressionFlatImpl cel_expr(std::move(path), &TestTypeRegistry(), - cel::RuntimeOptions{}, {}); + cel::RuntimeOptions{}); Activation activation; google::protobuf::Arena arena; diff --git a/eval/eval/create_struct_step_test.cc b/eval/eval/create_struct_step_test.cc index 9a729a574..2dff67093 100644 --- a/eval/eval/create_struct_step_test.cc +++ b/eval/eval/create_struct_step_test.cc @@ -80,7 +80,7 @@ absl::StatusOr RunExpression(absl::string_view field, if (enable_unknowns) { options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; } - CelExpressionFlatImpl cel_expr(std::move(path), &type_registry, options, {}); + CelExpressionFlatImpl cel_expr(std::move(path), &type_registry, options); Activation activation; activation.InsertValue("message", value); @@ -171,8 +171,7 @@ absl::StatusOr RunCreateMapExpression( options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; } - CelExpressionFlatImpl cel_expr(std::move(path), &TestTypeRegistry(), options, - {}); + CelExpressionFlatImpl cel_expr(std::move(path), &TestTypeRegistry(), options); return cel_expr.Evaluate(activation, arena); } @@ -201,7 +200,7 @@ TEST_P(CreateCreateStructStepTest, TestEmptyMessageCreation) { if (GetParam()) { options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; } - CelExpressionFlatImpl cel_expr(std::move(path), &type_registry, options, {}); + CelExpressionFlatImpl cel_expr(std::move(path), &type_registry, options); Activation activation; google::protobuf::Arena arena; diff --git a/eval/eval/evaluator_core.cc b/eval/eval/evaluator_core.cc index a2182f335..af6a341be 100644 --- a/eval/eval/evaluator_core.cc +++ b/eval/eval/evaluator_core.cc @@ -31,11 +31,9 @@ absl::Status InvalidIterationStateError() { // TODO(issues/5): cel::TypeFactory and family are setup here assuming legacy // value interop. Later, these will need to be configurable by clients. CelExpressionFlatEvaluationState::CelExpressionFlatEvaluationState( - size_t value_stack_size, const std::set& iter_variable_names, - google::protobuf::Arena* arena) + size_t value_stack_size, google::protobuf::Arena* arena) : memory_manager_(arena), value_stack_(value_stack_size), - iter_variable_names_(iter_variable_names), type_factory_(memory_manager_), type_manager_(type_factory_, cel::TypeProvider::Builtin()), value_factory_(type_manager_) {} @@ -141,8 +139,8 @@ bool ExecutionFrame::GetIterVar(absl::string_view name, std::unique_ptr CelExpressionFlatImpl::InitializeState( google::protobuf::Arena* arena) const { - return std::make_unique( - path_.size(), iter_variable_names_, arena); + return std::make_unique(path_.size(), + arena); } absl::StatusOr CelExpressionFlatImpl::Evaluate( diff --git a/eval/eval/evaluator_core.h b/eval/eval/evaluator_core.h index 9678561b9..f5f94d669 100644 --- a/eval/eval/evaluator_core.h +++ b/eval/eval/evaluator_core.h @@ -74,12 +74,13 @@ class ExpressionStep { }; using ExecutionPath = std::vector>; +using ExecutionPathView = + absl::Span>; class CelExpressionFlatEvaluationState : public CelEvaluationState { public: - CelExpressionFlatEvaluationState( - size_t value_stack_size, const std::set& iter_variable_names, - google::protobuf::Arena* arena); + CelExpressionFlatEvaluationState(size_t value_stack_size, + google::protobuf::Arena* arena); struct ComprehensionVarEntry { absl::string_view name; @@ -101,8 +102,6 @@ class CelExpressionFlatEvaluationState : public CelEvaluationState { IterFrame& IterStackTop() { return iter_stack_[iter_stack().size() - 1]; } - std::set& iter_variable_names() { return iter_variable_names_; } - google::protobuf::Arena* arena() { return memory_manager_.arena(); } cel::MemoryManager& memory_manager() { return memory_manager_; } @@ -119,7 +118,6 @@ class CelExpressionFlatEvaluationState : public CelEvaluationState { // manager they want to use for evaluation. cel::extensions::ProtoMemoryManager memory_manager_; EvaluatorStack value_stack_; - std::set iter_variable_names_; std::vector iter_stack_; cel::TypeFactory type_factory_; cel::TypeManager type_manager_; @@ -134,7 +132,7 @@ class ExecutionFrame { // activation provides bindings between parameter names and values. // arena serves as allocation manager during the expression evaluation. - ExecutionFrame(const ExecutionPath& flat, const BaseActivation& activation, + ExecutionFrame(ExecutionPathView flat, const BaseActivation& activation, const CelTypeRegistry* type_registry, const cel::RuntimeOptions& options, CelExpressionFlatEvaluationState* state) @@ -254,7 +252,7 @@ class ExecutionFrame { private: size_t pc_; // pc_ - Program Counter. Current position on execution path. - const ExecutionPath& execution_path_; + ExecutionPathView execution_path_; const BaseActivation& activation_; cel::interop_internal::AdapterActivationImpl modern_activation_; const CelTypeRegistry& type_registry_; @@ -277,13 +275,11 @@ class CelExpressionFlatImpl : public CelExpression { CelExpressionFlatImpl(ExecutionPath path, const CelTypeRegistry* type_registry, const cel::RuntimeOptions& options, - std::set iter_variable_names, std::unique_ptr arena = nullptr) : arena_(std::move(arena)), path_(std::move(path)), type_registry_(*type_registry), - options_(options), - iter_variable_names_(std::move(iter_variable_names)) {} + options_(options) {} // Move-only CelExpressionFlatImpl(const CelExpressionFlatImpl&) = delete; @@ -318,7 +314,6 @@ class CelExpressionFlatImpl : public CelExpression { const ExecutionPath path_; const CelTypeRegistry& type_registry_; cel::RuntimeOptions options_; - const std::set iter_variable_names_; }; } // namespace google::api::expr::runtime diff --git a/eval/eval/evaluator_core_test.cc b/eval/eval/evaluator_core_test.cc index 0aca625cd..fa3017f53 100644 --- a/eval/eval/evaluator_core_test.cc +++ b/eval/eval/evaluator_core_test.cc @@ -75,7 +75,7 @@ TEST(EvaluatorCoreTest, ExecutionFrameNext) { cel::RuntimeOptions options; options.unknown_processing = cel::UnknownProcessingOptions::kDisabled; Activation activation; - CelExpressionFlatEvaluationState state(path.size(), {}, nullptr); + CelExpressionFlatEvaluationState state(path.size(), nullptr); ExecutionFrame frame(path, activation, &TestTypeRegistry(), options, &state); EXPECT_THAT(frame.Next(), Eq(path[0].get())); @@ -94,7 +94,7 @@ TEST(EvaluatorCoreTest, ExecutionFrameSetGetClearVar) { google::protobuf::Arena arena; ProtoMemoryManager manager(&arena); ExecutionPath path; - CelExpressionFlatEvaluationState state(path.size(), {test_iter_var}, nullptr); + CelExpressionFlatEvaluationState state(path.size(), nullptr); cel::RuntimeOptions options; options.unknown_processing = cel::UnknownProcessingOptions::kDisabled; ExecutionFrame frame(path, activation, &TestTypeRegistry(), options, &state); @@ -157,7 +157,7 @@ TEST(EvaluatorCoreTest, SimpleEvaluatorTest) { path.push_back(std::move(incr_step2)); CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), - cel::RuntimeOptions{}, {}); + cel::RuntimeOptions{}); Activation activation; google::protobuf::Arena arena; diff --git a/eval/eval/function_step_test.cc b/eval/eval/function_step_test.cc index 7e7da15ad..f4db07873 100644 --- a/eval/eval/function_step_test.cc +++ b/eval/eval/function_step_test.cc @@ -222,7 +222,7 @@ class FunctionStepTest options.unknown_processing = GetParam(); return std::make_unique( - std::move(path), &TestTypeRegistry(), options, std::set()); + std::move(path), &TestTypeRegistry(), options); } }; @@ -569,7 +569,7 @@ class FunctionStepTestUnknowns options.unknown_processing = GetParam(); return std::make_unique( - std::move(path), &TestTypeRegistry(), options, std::set()); + std::move(path), &TestTypeRegistry(), options); } }; @@ -705,7 +705,7 @@ TEST(FunctionStepTestUnknownFunctionResults, CaptureArgs) { cel::RuntimeOptions options; options.unknown_processing = cel::UnknownProcessingOptions::kAttributeAndFunction; - CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), options, {}); + CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), options); Activation activation; google::protobuf::Arena arena; @@ -750,7 +750,7 @@ TEST(FunctionStepTestUnknownFunctionResults, MergeDownCaptureArgs) { cel::RuntimeOptions options; options.unknown_processing = cel::UnknownProcessingOptions::kAttributeAndFunction; - CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), options, {}); + CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), options); Activation activation; google::protobuf::Arena arena; @@ -795,7 +795,7 @@ TEST(FunctionStepTestUnknownFunctionResults, MergeCaptureArgs) { cel::RuntimeOptions options; options.unknown_processing = cel::UnknownProcessingOptions::kAttributeAndFunction; - CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), options, {}); + CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), options); Activation activation; google::protobuf::Arena arena; @@ -835,7 +835,7 @@ TEST(FunctionStepTestUnknownFunctionResults, UnknownVsErrorPrecedenceTest) { cel::RuntimeOptions options; options.unknown_processing = cel::UnknownProcessingOptions::kAttributeAndFunction; - CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), options, {}); + CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), options); Activation activation; google::protobuf::Arena arena; @@ -920,7 +920,7 @@ TEST(FunctionStepStrictnessTest, cel::RuntimeOptions options; options.unknown_processing = cel::UnknownProcessingOptions::kAttributeAndFunction; - CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), options, {}); + CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), options); Activation activation; google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN(CelValue value, impl.Evaluate(activation, &arena)); @@ -947,7 +947,7 @@ TEST(FunctionStepStrictnessTest, IfFunctionNonStrictAndGivenUnknownInvokesIt) { cel::RuntimeOptions options; options.unknown_processing = cel::UnknownProcessingOptions::kAttributeAndFunction; - CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), options, {}); + CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), options); Activation activation; google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN(CelValue value, impl.Evaluate(activation, &arena)); diff --git a/eval/eval/ident_step_test.cc b/eval/eval/ident_step_test.cc index d7207fb10..107a9b5ee 100644 --- a/eval/eval/ident_step_test.cc +++ b/eval/eval/ident_step_test.cc @@ -32,7 +32,7 @@ TEST(IdentStepTest, TestIdentStep) { path.push_back(std::move(step)); CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), - cel::RuntimeOptions{}, {}); + cel::RuntimeOptions{}); Activation activation; Arena arena; @@ -59,7 +59,7 @@ TEST(IdentStepTest, TestIdentStepNameNotFound) { path.push_back(std::move(step)); CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), - cel::RuntimeOptions{}, {}); + cel::RuntimeOptions{}); Activation activation; Arena arena; @@ -83,7 +83,7 @@ TEST(IdentStepTest, DisableMissingAttributeErrorsOK) { path.push_back(std::move(step)); cel::RuntimeOptions options; options.unknown_processing = cel::UnknownProcessingOptions::kDisabled; - CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), options, {}); + CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), options); Activation activation; Arena arena; @@ -121,7 +121,7 @@ TEST(IdentStepTest, TestIdentStepMissingAttributeErrors) { options.unknown_processing = cel::UnknownProcessingOptions::kDisabled; options.enable_missing_attribute_errors = true; - CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), options, {}); + CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), options); Activation activation; Arena arena; @@ -159,7 +159,7 @@ TEST(IdentStepTest, TestIdentStepUnknownAttribute) { // Expression with unknowns enabled. cel::RuntimeOptions options; options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; - CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), options, {}); + CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), options); Activation activation; Arena arena; diff --git a/eval/eval/select_step_test.cc b/eval/eval/select_step_test.cc index 0c28168d9..93eb5afe3 100644 --- a/eval/eval/select_step_test.cc +++ b/eval/eval/select_step_test.cc @@ -99,7 +99,7 @@ absl::StatusOr RunExpression(const CelValue target, cel::UnknownProcessingOptions::kAttributeOnly; } CelExpressionFlatImpl cel_expr(std::move(path), &TestTypeRegistry(), - runtime_options, {}); + runtime_options); Activation activation; activation.InsertValue("target", target); @@ -285,7 +285,7 @@ TEST(SelectStepTest, MapPresenseIsErrorTest) { path.push_back(std::move(step1)); path.push_back(std::move(step2)); CelExpressionFlatImpl cel_expr(std::move(path), &TestTypeRegistry(), - cel::RuntimeOptions{}, {}); + cel::RuntimeOptions{}); Activation activation; activation.InsertValue("target", CelProtoWrapper::CreateMessage(&message, &arena)); @@ -819,8 +819,7 @@ TEST_P(SelectStepTest, CelErrorAsArgument) { if (GetParam()) { options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; } - CelExpressionFlatImpl cel_expr(std::move(path), &TestTypeRegistry(), options, - {}); + CelExpressionFlatImpl cel_expr(std::move(path), &TestTypeRegistry(), options); Activation activation; activation.InsertValue("message", CelValue::CreateError(&error)); @@ -854,7 +853,7 @@ TEST(SelectStepTest, DisableMissingAttributeOK) { path.push_back(std::move(step1)); CelExpressionFlatImpl cel_expr(std::move(path), &TestTypeRegistry(), - cel::RuntimeOptions{}, {}); + cel::RuntimeOptions{}); Activation activation; activation.InsertValue("message", CelProtoWrapper::CreateMessage(&message, &arena)); @@ -896,8 +895,7 @@ TEST(SelectStepTest, UnrecoverableUnknownValueProducesError) { cel::RuntimeOptions options; options.enable_missing_attribute_errors = true; - CelExpressionFlatImpl cel_expr(std::move(path), &TestTypeRegistry(), options, - {}); + CelExpressionFlatImpl cel_expr(std::move(path), &TestTypeRegistry(), options); Activation activation; activation.InsertValue("message", CelProtoWrapper::CreateMessage(&message, &arena)); @@ -945,8 +943,7 @@ TEST(SelectStepTest, UnknownPatternResolvesToUnknown) { cel::RuntimeOptions options; options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; - CelExpressionFlatImpl cel_expr(std::move(path), &TestTypeRegistry(), options, - {}); + CelExpressionFlatImpl cel_expr(std::move(path), &TestTypeRegistry(), options); { std::vector unknown_patterns; From b37a663c0e25ad37b98013b5af3ee0ae349ed952 Mon Sep 17 00:00:00 2001 From: tswadell Date: Tue, 18 Apr 2023 14:48:42 +0000 Subject: [PATCH 218/303] Internal tooling change PiperOrigin-RevId: 525144992 --- eval/testutil/test_message.proto | 2 ++ 1 file changed, 2 insertions(+) diff --git a/eval/testutil/test_message.proto b/eval/testutil/test_message.proto index 513fe7815..464c28165 100644 --- a/eval/testutil/test_message.proto +++ b/eval/testutil/test_message.proto @@ -67,6 +67,8 @@ message TestMessage { map bool_int32_map = 204; map int32_int32_map = 205; map uint32_uint32_map = 206; + map int32_float_map = 207; + map int64_enum_map = 208; // Well-known types. google.protobuf.Any any_value = 300; From 50c0efae5a2a91f5d096289951033f3f70f17bbf Mon Sep 17 00:00:00 2001 From: tswadell Date: Tue, 18 Apr 2023 18:47:12 +0000 Subject: [PATCH 219/303] Internal tool change PiperOrigin-RevId: 525208796 --- eval/testutil/test_message.proto | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/eval/testutil/test_message.proto b/eval/testutil/test_message.proto index 464c28165..1501236f5 100644 --- a/eval/testutil/test_message.proto +++ b/eval/testutil/test_message.proto @@ -45,21 +45,17 @@ message TestMessage { repeated int32 int32_list = 101; repeated int64 int64_list = 102; - repeated uint32 uint32_list = 103; repeated uint64 uint64_list = 104; - repeated float float_list = 105; repeated double double_list = 106; - repeated string string_list = 107; repeated string cord_list = 108 [ctype = CORD]; repeated bytes bytes_list = 109; - repeated bool bool_list = 110; - repeated TestEnum enum_list = 111; repeated TestMessage message_list = 112; + repeated google.protobuf.Timestamp timestamp_list = 113; map int64_int32_map = 201; map uint64_int32_map = 202; @@ -69,6 +65,7 @@ message TestMessage { map uint32_uint32_map = 206; map int32_float_map = 207; map int64_enum_map = 208; + map string_timestamp_map = 209; // Well-known types. google.protobuf.Any any_value = 300; From d57e8f1642b5d1fa63f0519b3e4efef0cf850499 Mon Sep 17 00:00:00 2001 From: jcking Date: Fri, 21 Apr 2023 17:19:38 +0000 Subject: [PATCH 220/303] Redo `base/operators.h` for better type safety PiperOrigin-RevId: 526072974 --- base/BUILD | 9 +- base/internal/operators.h | 44 +++- base/operators.cc | 370 +++++++++++++++++++++----------- base/operators.h | 430 ++++++++++++++++++++++++++++++++++---- base/operators_test.cc | 332 ++++++++++++----------------- 5 files changed, 826 insertions(+), 359 deletions(-) diff --git a/base/BUILD b/base/BUILD index f497212ca..cde6012a9 100644 --- a/base/BUILD +++ b/base/BUILD @@ -132,10 +132,9 @@ cc_library( "//base/internal:operators", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", ], ) @@ -144,9 +143,11 @@ cc_test( srcs = ["operators_test.cc"], deps = [ ":operators", + "//base/internal:operators", "//internal:testing", "@com_google_absl//absl/hash:hash_testing", - "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", ], ) diff --git a/base/internal/operators.h b/base/internal/operators.h index 84159dcca..04ffe2d79 100644 --- a/base/internal/operators.h +++ b/base/internal/operators.h @@ -1,4 +1,4 @@ -// Copyright 2021 Google LLC +// Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -25,10 +25,10 @@ namespace base_internal { struct OperatorData final { OperatorData() = delete; - OperatorData(const OperatorData&) = delete; - OperatorData(OperatorData&&) = delete; + OperatorData& operator=(const OperatorData&) = delete; + OperatorData& operator=(OperatorData&&) = delete; constexpr OperatorData(cel::OperatorId id, absl::string_view name, absl::string_view display_name, int precedence, @@ -46,6 +46,44 @@ struct OperatorData final { const int arity; }; +#define CEL_INTERNAL_UNARY_OPERATORS_ENUM(XX) \ + XX(LogicalNot, "!", "!_", 2, 1) \ + XX(Negate, "-", "-_", 2, 1) \ + XX(NotStrictlyFalse, "", "@not_strictly_false", 0, 1) \ + XX(OldNotStrictlyFalse, "", "__not_strictly_false__", 0, 1) + +#define CEL_INTERNAL_BINARY_OPERATORS_ENUM(XX) \ + XX(Equals, "==", "_==_", 5, 2) \ + XX(NotEquals, "!=", "_!=_", 5, 2) \ + XX(Less, "<", "_<_", 5, 2) \ + XX(LessEquals, "<=", "_<=_", 5, 2) \ + XX(Greater, ">", "_>_", 5, 2) \ + XX(GreaterEquals, ">=", "_>=_", 5, 2) \ + XX(In, "in", "@in", 5, 2) \ + XX(OldIn, "in", "_in_", 5, 2) \ + XX(Index, "", "_[_]", 1, 2) \ + XX(LogicalOr, "||", "_||_", 7, 2) \ + XX(LogicalAnd, "&&", "_&&_", 6, 2) \ + XX(Add, "+", "_+_", 4, 2) \ + XX(Subtract, "-", "_-_", 4, 2) \ + XX(Multiply, "*", "_*_", 3, 2) \ + XX(Divide, "/", "_/_", 3, 2) \ + XX(Modulo, "%", "_%_", 3, 2) + +#define CEL_INTERNAL_TERNARY_OPERATORS_ENUM(XX) \ + XX(Conditional, "", "_?_:_", 8, 3) + +// Macro definining all the operators and their properties. +// (1) - The identifier. +// (2) - The display name if applicable, otherwise an empty string. +// (3) - The name. +// (4) - The precedence if applicable, otherwise 0. +// (5) - The arity. +#define CEL_INTERNAL_OPERATORS_ENUM(XX) \ + CEL_INTERNAL_TERNARY_OPERATORS_ENUM(XX) \ + CEL_INTERNAL_BINARY_OPERATORS_ENUM(XX) \ + CEL_INTERNAL_UNARY_OPERATORS_ENUM(XX) + } // namespace base_internal } // namespace cel diff --git a/base/operators.cc b/base/operators.cc index 5dc6975ec..805acc5a1 100644 --- a/base/operators.cc +++ b/base/operators.cc @@ -1,4 +1,4 @@ -// Copyright 2021 Google LLC +// Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,157 +14,287 @@ #include "base/operators.h" -#include +#include +#include #include "absl/base/attributes.h" #include "absl/base/call_once.h" -#include "absl/base/macros.h" -#include "absl/container/flat_hash_map.h" -#include "absl/status/status.h" -#include "absl/strings/str_cat.h" - -// Macro definining all the operators and their properties. -// (1) - The identifier. -// (2) - The display name if applicable, otherwise an empty string. -// (3) - The name. -// (4) - The precedence if applicable, otherwise 0. -// (5) - The arity. -#define CEL_OPERATORS_ENUM(XX) \ - XX(Conditional, "", "_?_:_", 8, 3) \ - XX(LogicalOr, "||", "_||_", 7, 2) \ - XX(LogicalAnd, "&&", "_&&_", 6, 2) \ - XX(Equals, "==", "_==_", 5, 2) \ - XX(NotEquals, "!=", "_!=_", 5, 2) \ - XX(Less, "<", "_<_", 5, 2) \ - XX(LessEquals, "<=", "_<=_", 5, 2) \ - XX(Greater, ">", "_>_", 5, 2) \ - XX(GreaterEquals, ">=", "_>=_", 5, 2) \ - XX(In, "in", "@in", 5, 2) \ - XX(OldIn, "in", "_in_", 5, 2) \ - XX(Add, "+", "_+_", 4, 2) \ - XX(Subtract, "-", "_-_", 4, 2) \ - XX(Multiply, "*", "_*_", 3, 2) \ - XX(Divide, "/", "_/_", 3, 2) \ - XX(Modulo, "%", "_%_", 3, 2) \ - XX(LogicalNot, "!", "!_", 2, 1) \ - XX(Negate, "-", "-_", 2, 1) \ - XX(Index, "", "_[_]", 1, 2) \ - XX(NotStrictlyFalse, "", "@not_strictly_false", 0, 1) \ - XX(OldNotStrictlyFalse, "", "__not_strictly_false__", 0, 1) +#include "absl/log/absl_check.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "base/internal/operators.h" namespace cel { namespace { +using base_internal::OperatorData; + +struct OperatorDataNameComparer { + using is_transparent = void; + + bool operator()(const OperatorData* lhs, const OperatorData* rhs) const { + return lhs->name < rhs->name; + } + + bool operator()(const OperatorData* lhs, absl::string_view rhs) const { + return lhs->name < rhs; + } + + bool operator()(absl::string_view lhs, const OperatorData* rhs) const { + return lhs < rhs->name; + } +}; + +struct OperatorDataDisplayNameComparer { + using is_transparent = void; + + bool operator()(const OperatorData* lhs, const OperatorData* rhs) const { + return lhs->display_name < rhs->display_name; + } + + bool operator()(const OperatorData* lhs, absl::string_view rhs) const { + return lhs->display_name < rhs; + } + + bool operator()(absl::string_view lhs, const OperatorData* rhs) const { + return lhs < rhs->display_name; + } +}; + +#define CEL_OPERATORS_DATA(id, symbol, name, precedence, arity) \ + ABSL_CONST_INIT const OperatorData id##_storage = { \ + OperatorId::k##id, name, symbol, precedence, arity}; +CEL_INTERNAL_OPERATORS_ENUM(CEL_OPERATORS_DATA) +#undef CEL_OPERATORS_DATA + +#define CEL_OPERATORS_COUNT(id, symbol, name, precedence, arity) +1 + +using OperatorsArray = + std::array; + +using UnaryOperatorsArray = + std::array; + +using BinaryOperatorsArray = + std::array; + +using TernaryOperatorsArray = + std::array; + +#undef CEL_OPERATORS_COUNT + ABSL_CONST_INIT absl::once_flag operators_once_flag; -ABSL_CONST_INIT const absl::flat_hash_map* - operators_by_name = nullptr; -ABSL_CONST_INIT const absl::flat_hash_map* - operators_by_display_name = nullptr; -ABSL_CONST_INIT const absl::flat_hash_map* - unary_operators = nullptr; -ABSL_CONST_INIT const absl::flat_hash_map* - binary_operators = nullptr; + +#define CEL_OPERATORS_DO(id, symbol, name, precedence, arity) &id##_storage, + +OperatorsArray operators_by_name = { + CEL_INTERNAL_OPERATORS_ENUM(CEL_OPERATORS_DO)}; + +OperatorsArray operators_by_display_name = { + CEL_INTERNAL_OPERATORS_ENUM(CEL_OPERATORS_DO)}; + +UnaryOperatorsArray unary_operators_by_name = { + CEL_INTERNAL_UNARY_OPERATORS_ENUM(CEL_OPERATORS_DO)}; + +UnaryOperatorsArray unary_operators_by_display_name = { + CEL_INTERNAL_UNARY_OPERATORS_ENUM(CEL_OPERATORS_DO)}; + +BinaryOperatorsArray binary_operators_by_name = { + CEL_INTERNAL_BINARY_OPERATORS_ENUM(CEL_OPERATORS_DO)}; + +BinaryOperatorsArray binary_operators_by_display_name = { + CEL_INTERNAL_BINARY_OPERATORS_ENUM(CEL_OPERATORS_DO)}; + +TernaryOperatorsArray ternary_operators_by_name = { + CEL_INTERNAL_TERNARY_OPERATORS_ENUM(CEL_OPERATORS_DO)}; + +TernaryOperatorsArray ternary_operators_by_display_name = { + CEL_INTERNAL_TERNARY_OPERATORS_ENUM(CEL_OPERATORS_DO)}; + +#undef CEL_OPERATORS_DO void InitializeOperators() { - ABSL_ASSERT(operators_by_name == nullptr); - ABSL_ASSERT(operators_by_display_name == nullptr); - ABSL_ASSERT(unary_operators == nullptr); - ABSL_ASSERT(binary_operators == nullptr); - auto operators_by_name_ptr = - std::make_unique>(); - auto operators_by_display_name_ptr = - std::make_unique>(); - auto unary_operators_ptr = - std::make_unique>(); - auto binary_operators_ptr = - std::make_unique>(); - -#define CEL_DEFINE_OPERATORS_BY_NAME(id, symbol, name, precedence, arity) \ - if constexpr (!absl::string_view(name).empty()) { \ - operators_by_name_ptr->insert({name, Operator::id()}); \ - } - CEL_OPERATORS_ENUM(CEL_DEFINE_OPERATORS_BY_NAME) -#undef CEL_DEFINE_OPERATORS_BY_NAME - -#define CEL_DEFINE_OPERATORS_BY_SYMBOL(id, symbol, name, precedence, arity) \ - if constexpr (!absl::string_view(symbol).empty()) { \ - operators_by_display_name_ptr->insert({symbol, Operator::id()}); \ - } - CEL_OPERATORS_ENUM(CEL_DEFINE_OPERATORS_BY_SYMBOL) -#undef CEL_DEFINE_OPERATORS_BY_SYMBOL - -#define CEL_DEFINE_UNARY_OPERATORS(id, symbol, name, precedence, arity) \ - if constexpr (!absl::string_view(symbol).empty() && arity == 1) { \ - unary_operators_ptr->insert({symbol, Operator::id()}); \ - } - CEL_OPERATORS_ENUM(CEL_DEFINE_UNARY_OPERATORS) -#undef CEL_DEFINE_UNARY_OPERATORS - -#define CEL_DEFINE_BINARY_OPERATORS(id, symbol, name, precedence, arity) \ - if constexpr (!absl::string_view(symbol).empty() && arity == 2) { \ - binary_operators_ptr->insert({symbol, Operator::id()}); \ - } - CEL_OPERATORS_ENUM(CEL_DEFINE_BINARY_OPERATORS) -#undef CEL_DEFINE_BINARY_OPERATORS - - operators_by_name = operators_by_name_ptr.release(); - operators_by_display_name = operators_by_display_name_ptr.release(); - unary_operators = unary_operators_ptr.release(); - binary_operators = binary_operators_ptr.release(); + std::stable_sort(operators_by_name.begin(), operators_by_name.end(), + OperatorDataNameComparer{}); + std::stable_sort(operators_by_display_name.begin(), + operators_by_display_name.end(), + OperatorDataDisplayNameComparer{}); + std::stable_sort(unary_operators_by_name.begin(), + unary_operators_by_name.end(), OperatorDataNameComparer{}); + std::stable_sort(unary_operators_by_display_name.begin(), + unary_operators_by_display_name.end(), + OperatorDataDisplayNameComparer{}); + std::stable_sort(binary_operators_by_name.begin(), + binary_operators_by_name.end(), OperatorDataNameComparer{}); + std::stable_sort(binary_operators_by_display_name.begin(), + binary_operators_by_display_name.end(), + OperatorDataDisplayNameComparer{}); + std::stable_sort(ternary_operators_by_name.begin(), + ternary_operators_by_name.end(), OperatorDataNameComparer{}); + std::stable_sort(ternary_operators_by_display_name.begin(), + ternary_operators_by_display_name.end(), + OperatorDataDisplayNameComparer{}); } -#define CEL_DEFINE_OPERATOR_DATA(id, symbol, name, precedence, arity) \ - ABSL_CONST_INIT constexpr base_internal::OperatorData k##id##Data( \ - OperatorId::k##id, name, symbol, precedence, arity); -CEL_OPERATORS_ENUM(CEL_DEFINE_OPERATOR_DATA) -#undef CEL_DEFINE_OPERATOR_DATA - } // namespace -#define CEL_DEFINE_OPERATOR(id, symbol, name, precedence, arity) \ - Operator Operator::id() { return Operator(std::addressof(k##id##Data)); } -CEL_OPERATORS_ENUM(CEL_DEFINE_OPERATOR) -#undef CEL_DEFINE_OPERATOR +UnaryOperator::UnaryOperator(Operator op) : data_(op.data_) { + ABSL_CHECK(op.arity() == Arity::kUnary); // Crask OK +} + +BinaryOperator::BinaryOperator(Operator op) : data_(op.data_) { + ABSL_CHECK(op.arity() == Arity::kBinary); // Crask OK +} + +TernaryOperator::TernaryOperator(Operator op) : data_(op.data_) { + ABSL_CHECK(op.arity() == Arity::kTernary); // Crask OK +} + +#define CEL_UNARY_OPERATOR(id, symbol, name, precedence, arity) \ + UnaryOperator Operator::id() { return UnaryOperator(&id##_storage); } + +CEL_INTERNAL_UNARY_OPERATORS_ENUM(CEL_UNARY_OPERATOR) + +#undef CEL_UNARY_OPERATOR + +#define CEL_BINARY_OPERATOR(id, symbol, name, precedence, arity) \ + BinaryOperator Operator::id() { return BinaryOperator(&id##_storage); } + +CEL_INTERNAL_BINARY_OPERATORS_ENUM(CEL_BINARY_OPERATOR) -absl::StatusOr Operator::FindByName(absl::string_view input) { +#undef CEL_BINARY_OPERATOR + +#define CEL_TERNARY_OPERATOR(id, symbol, name, precedence, arity) \ + TernaryOperator Operator::id() { return TernaryOperator(&id##_storage); } + +CEL_INTERNAL_TERNARY_OPERATORS_ENUM(CEL_TERNARY_OPERATOR) + +#undef CEL_TERNARY_OPERATOR + +absl::optional Operator::FindByName(absl::string_view input) { absl::call_once(operators_once_flag, InitializeOperators); - auto it = operators_by_name->find(input); - if (it != operators_by_name->end()) { - return it->second; + if (input.empty()) { + return absl::nullopt; + } + auto it = + std::lower_bound(operators_by_name.cbegin(), operators_by_name.cend(), + input, OperatorDataNameComparer{}); + if (it == operators_by_name.cend() || (*it)->name != input) { + return absl::nullopt; } - return absl::NotFoundError(absl::StrCat("No such operator: ", input)); + return Operator(*it); } -absl::StatusOr Operator::FindByDisplayName(absl::string_view input) { +absl::optional Operator::FindByDisplayName(absl::string_view input) { absl::call_once(operators_once_flag, InitializeOperators); - auto it = operators_by_display_name->find(input); - if (it != operators_by_display_name->end()) { - return it->second; + if (input.empty()) { + return absl::nullopt; } - return absl::NotFoundError(absl::StrCat("No such operator: ", input)); + auto it = std::lower_bound(operators_by_display_name.cbegin(), + operators_by_display_name.cend(), input, + OperatorDataDisplayNameComparer{}); + if (it == operators_by_name.cend() || (*it)->display_name != input) { + return absl::nullopt; + } + return Operator(*it); } -absl::StatusOr Operator::FindUnaryByDisplayName( +absl::optional UnaryOperator::FindByName( absl::string_view input) { absl::call_once(operators_once_flag, InitializeOperators); - auto it = unary_operators->find(input); - if (it != unary_operators->end()) { - return it->second; + if (input.empty()) { + return absl::nullopt; + } + auto it = std::lower_bound(unary_operators_by_name.cbegin(), + unary_operators_by_name.cend(), input, + OperatorDataNameComparer{}); + if (it == unary_operators_by_name.cend() || (*it)->name != input) { + return absl::nullopt; } - return absl::NotFoundError(absl::StrCat("No such unary operator: ", input)); + return UnaryOperator(*it); } -absl::StatusOr Operator::FindBinaryByDisplayName( +absl::optional UnaryOperator::FindByDisplayName( absl::string_view input) { absl::call_once(operators_once_flag, InitializeOperators); - auto it = binary_operators->find(input); - if (it != binary_operators->end()) { - return it->second; + if (input.empty()) { + return absl::nullopt; + } + auto it = std::lower_bound(unary_operators_by_display_name.cbegin(), + unary_operators_by_display_name.cend(), input, + OperatorDataDisplayNameComparer{}); + if (it == unary_operators_by_display_name.cend() || + (*it)->display_name != input) { + return absl::nullopt; } - return absl::NotFoundError(absl::StrCat("No such binary operator: ", input)); + return UnaryOperator(*it); } -} // namespace cel +absl::optional BinaryOperator::FindByName( + absl::string_view input) { + absl::call_once(operators_once_flag, InitializeOperators); + if (input.empty()) { + return absl::nullopt; + } + auto it = std::lower_bound(binary_operators_by_name.cbegin(), + binary_operators_by_name.cend(), input, + OperatorDataNameComparer{}); + if (it == binary_operators_by_name.cend() || (*it)->name != input) { + return absl::nullopt; + } + return BinaryOperator(*it); +} -#undef CEL_OPERATORS_ENUM +absl::optional BinaryOperator::FindByDisplayName( + absl::string_view input) { + absl::call_once(operators_once_flag, InitializeOperators); + if (input.empty()) { + return absl::nullopt; + } + auto it = std::lower_bound(binary_operators_by_display_name.cbegin(), + binary_operators_by_display_name.cend(), input, + OperatorDataDisplayNameComparer{}); + if (it == binary_operators_by_display_name.cend() || + (*it)->display_name != input) { + return absl::nullopt; + } + return BinaryOperator(*it); +} + +absl::optional TernaryOperator::FindByName( + absl::string_view input) { + absl::call_once(operators_once_flag, InitializeOperators); + if (input.empty()) { + return absl::nullopt; + } + auto it = std::lower_bound(ternary_operators_by_name.cbegin(), + ternary_operators_by_name.cend(), input, + OperatorDataNameComparer{}); + if (it == ternary_operators_by_name.cend() || (*it)->name != input) { + return absl::nullopt; + } + return TernaryOperator(*it); +} + +absl::optional TernaryOperator::FindByDisplayName( + absl::string_view input) { + absl::call_once(operators_once_flag, InitializeOperators); + if (input.empty()) { + return absl::nullopt; + } + auto it = std::lower_bound(ternary_operators_by_display_name.cbegin(), + ternary_operators_by_display_name.cend(), input, + OperatorDataDisplayNameComparer{}); + if (it == ternary_operators_by_display_name.cend() || + (*it)->display_name != input) { + return absl::nullopt; + } + return TernaryOperator(*it); +} + +} // namespace cel diff --git a/base/operators.h b/base/operators.h index 7cd40d911..778262c4b 100644 --- a/base/operators.h +++ b/base/operators.h @@ -1,4 +1,4 @@ -// Copyright 2021 Google LLC +// Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -18,12 +18,19 @@ #include #include "absl/base/attributes.h" -#include "absl/status/statusor.h" +#include "absl/base/macros.h" #include "absl/strings/string_view.h" +#include "absl/types/optional.h" #include "base/internal/operators.h" namespace cel { +enum class Arity { + kUnary = 1, + kBinary = 2, + kTernary = 3, +}; + enum class OperatorId { kConditional = 1, kLogicalAnd, @@ -48,48 +55,74 @@ enum class OperatorId { kOldNotStrictlyFalse, }; +enum class UnaryOperatorId { + kLogicalNot = static_cast(OperatorId::kLogicalNot), + kNegate = static_cast(OperatorId::kNegate), + kNotStrictlyFalse = static_cast(OperatorId::kNotStrictlyFalse), + kOldNotStrictlyFalse = static_cast(OperatorId::kOldNotStrictlyFalse), +}; + +enum class BinaryOperatorId { + kLogicalAnd = static_cast(OperatorId::kLogicalAnd), + kLogicalOr = static_cast(OperatorId::kLogicalOr), + kEquals = static_cast(OperatorId::kEquals), + kNotEquals = static_cast(OperatorId::kNotEquals), + kLess = static_cast(OperatorId::kLess), + kLessEquals = static_cast(OperatorId::kLessEquals), + kGreater = static_cast(OperatorId::kGreater), + kGreaterEquals = static_cast(OperatorId::kGreaterEquals), + kAdd = static_cast(OperatorId::kAdd), + kSubtract = static_cast(OperatorId::kSubtract), + kMultiply = static_cast(OperatorId::kMultiply), + kDivide = static_cast(OperatorId::kDivide), + kModulo = static_cast(OperatorId::kModulo), + kIndex = static_cast(OperatorId::kIndex), + kIn = static_cast(OperatorId::kIn), + kOldIn = static_cast(OperatorId::kOldIn), +}; + +enum class TernaryOperatorId { + kConditional = static_cast(OperatorId::kConditional), +}; + +class UnaryOperator; +class BinaryOperator; +class TernaryOperator; + class Operator final { public: - ABSL_ATTRIBUTE_PURE_FUNCTION static Operator Conditional(); - ABSL_ATTRIBUTE_PURE_FUNCTION static Operator LogicalAnd(); - ABSL_ATTRIBUTE_PURE_FUNCTION static Operator LogicalOr(); - ABSL_ATTRIBUTE_PURE_FUNCTION static Operator LogicalNot(); - ABSL_ATTRIBUTE_PURE_FUNCTION static Operator Equals(); - ABSL_ATTRIBUTE_PURE_FUNCTION static Operator NotEquals(); - ABSL_ATTRIBUTE_PURE_FUNCTION static Operator Less(); - ABSL_ATTRIBUTE_PURE_FUNCTION static Operator LessEquals(); - ABSL_ATTRIBUTE_PURE_FUNCTION static Operator Greater(); - ABSL_ATTRIBUTE_PURE_FUNCTION static Operator GreaterEquals(); - ABSL_ATTRIBUTE_PURE_FUNCTION static Operator Add(); - ABSL_ATTRIBUTE_PURE_FUNCTION static Operator Subtract(); - ABSL_ATTRIBUTE_PURE_FUNCTION static Operator Multiply(); - ABSL_ATTRIBUTE_PURE_FUNCTION static Operator Divide(); - ABSL_ATTRIBUTE_PURE_FUNCTION static Operator Modulo(); - ABSL_ATTRIBUTE_PURE_FUNCTION static Operator Negate(); - ABSL_ATTRIBUTE_PURE_FUNCTION static Operator Index(); - ABSL_ATTRIBUTE_PURE_FUNCTION static Operator In(); - ABSL_ATTRIBUTE_PURE_FUNCTION static Operator NotStrictlyFalse(); - ABSL_ATTRIBUTE_PURE_FUNCTION static Operator OldIn(); - ABSL_ATTRIBUTE_PURE_FUNCTION static Operator OldNotStrictlyFalse(); - - static absl::StatusOr FindByName(absl::string_view input); - - static absl::StatusOr FindByDisplayName(absl::string_view input); - - static absl::StatusOr FindUnaryByDisplayName( - absl::string_view input); + ABSL_ATTRIBUTE_PURE_FUNCTION static TernaryOperator Conditional(); + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator LogicalAnd(); + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator LogicalOr(); + ABSL_ATTRIBUTE_PURE_FUNCTION static UnaryOperator LogicalNot(); + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Equals(); + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator NotEquals(); + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Less(); + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator LessEquals(); + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Greater(); + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator GreaterEquals(); + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Add(); + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Subtract(); + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Multiply(); + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Divide(); + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Modulo(); + ABSL_ATTRIBUTE_PURE_FUNCTION static UnaryOperator Negate(); + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Index(); + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator In(); + ABSL_ATTRIBUTE_PURE_FUNCTION static UnaryOperator NotStrictlyFalse(); + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator OldIn(); + ABSL_ATTRIBUTE_PURE_FUNCTION static UnaryOperator OldNotStrictlyFalse(); - static absl::StatusOr FindBinaryByDisplayName( + ABSL_ATTRIBUTE_PURE_FUNCTION static absl::optional FindByName( absl::string_view input); - Operator() = delete; + ABSL_ATTRIBUTE_PURE_FUNCTION static absl::optional + FindByDisplayName(absl::string_view input); + Operator() = delete; Operator(const Operator&) = default; - Operator(Operator&&) = default; - Operator& operator=(const Operator&) = default; - Operator& operator=(Operator&&) = default; constexpr OperatorId id() const { return data_->id; } @@ -108,9 +141,13 @@ class Operator final { constexpr int precedence() const { return data_->precedence; } - constexpr int arity() const { return data_->arity; } + constexpr Arity arity() const { return static_cast(data_->arity); } private: + friend class UnaryOperator; + friend class BinaryOperator; + friend class TernaryOperator; + constexpr explicit Operator(const base_internal::OperatorData* data) : data_(data) {} @@ -143,7 +180,326 @@ constexpr bool operator!=(const Operator& lhs, OperatorId rhs) { template H AbslHashValue(H state, const Operator& op) { - return H::combine(std::move(state), op.id()); + return H::combine(std::move(state), static_cast(op.id())); +} + +class UnaryOperator final { + public: + ABSL_ATTRIBUTE_PURE_FUNCTION static UnaryOperator LogicalNot() { + return Operator::LogicalNot(); + } + ABSL_ATTRIBUTE_PURE_FUNCTION static UnaryOperator Negate() { + return Operator::Negate(); + } + ABSL_ATTRIBUTE_PURE_FUNCTION static UnaryOperator NotStrictlyFalse() { + return Operator::NotStrictlyFalse(); + } + ABSL_ATTRIBUTE_PURE_FUNCTION static UnaryOperator OldNotStrictlyFalse() { + return Operator::OldNotStrictlyFalse(); + } + + ABSL_ATTRIBUTE_PURE_FUNCTION static absl::optional FindByName( + absl::string_view input); + + ABSL_ATTRIBUTE_PURE_FUNCTION static absl::optional + FindByDisplayName(absl::string_view input); + + UnaryOperator() = delete; + UnaryOperator(const UnaryOperator&) = default; + UnaryOperator(UnaryOperator&&) = default; + UnaryOperator& operator=(const UnaryOperator&) = default; + UnaryOperator& operator=(UnaryOperator&&) = default; + + // Support for explicit casting of Operator to UnaryOperator. + // `Operator::arity()` must return `Arity::kUnary`, or this will crash. + explicit UnaryOperator(Operator op); + + constexpr UnaryOperatorId id() const { + return static_cast(data_->id); + } + + // Returns the name of the operator. This is the managed representation of the + // operator, for example "_&&_". + constexpr absl::string_view name() const { return data_->name; } + + // Returns the source text representation of the operator. This is the + // unmanaged text representation of the operator, for example "&&". + // + // Note that this will be empty for operators like Conditional() and Index(). + constexpr absl::string_view display_name() const { + return data_->display_name; + } + + constexpr int precedence() const { return data_->precedence; } + + constexpr Arity arity() const { + ABSL_ASSERT(data_->arity == 1); + return Arity::kUnary; + } + + constexpr operator Operator() const { // NOLINT(google-explicit-constructor) + return Operator(data_); + } + + private: + friend class Operator; + + constexpr explicit UnaryOperator(const base_internal::OperatorData* data) + : data_(data) {} + + const base_internal::OperatorData* data_; +}; + +constexpr bool operator==(const UnaryOperator& lhs, const UnaryOperator& rhs) { + return lhs.id() == rhs.id(); +} + +constexpr bool operator==(UnaryOperatorId lhs, const UnaryOperator& rhs) { + return lhs == rhs.id(); +} + +constexpr bool operator==(const UnaryOperator& lhs, UnaryOperatorId rhs) { + return operator==(rhs, lhs); +} + +constexpr bool operator!=(const UnaryOperator& lhs, const UnaryOperator& rhs) { + return !operator==(lhs, rhs); +} + +constexpr bool operator!=(UnaryOperatorId lhs, const UnaryOperator& rhs) { + return !operator==(lhs, rhs); +} + +constexpr bool operator!=(const UnaryOperator& lhs, UnaryOperatorId rhs) { + return !operator==(lhs, rhs); +} + +template +H AbslHashValue(H state, const UnaryOperator& op) { + return H::combine(std::move(state), static_cast(op.id())); +} + +class BinaryOperator final { + public: + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator LogicalAnd() { + return Operator::LogicalAnd(); + } + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator LogicalOr() { + return Operator::LogicalOr(); + } + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Equals() { + return Operator::Equals(); + } + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator NotEquals() { + return Operator::NotEquals(); + } + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Less() { + return Operator::Less(); + } + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator LessEquals() { + return Operator::LessEquals(); + } + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Greater() { + return Operator::Greater(); + } + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator GreaterEquals() { + return Operator::GreaterEquals(); + } + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Add() { + return Operator::Add(); + } + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Subtract() { + return Operator::Subtract(); + } + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Multiply() { + return Operator::Multiply(); + } + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Divide() { + return Operator::Divide(); + } + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Modulo() { + return Operator::Modulo(); + } + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Index() { + return Operator::Index(); + } + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator In() { + return Operator::In(); + } + ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator OldIn() { + return Operator::OldIn(); + } + + ABSL_ATTRIBUTE_PURE_FUNCTION static absl::optional FindByName( + absl::string_view input); + + ABSL_ATTRIBUTE_PURE_FUNCTION static absl::optional + FindByDisplayName(absl::string_view input); + + BinaryOperator() = delete; + BinaryOperator(const BinaryOperator&) = default; + BinaryOperator(BinaryOperator&&) = default; + BinaryOperator& operator=(const BinaryOperator&) = default; + BinaryOperator& operator=(BinaryOperator&&) = default; + + // Support for explicit casting of Operator to BinaryOperator. + // `Operator::arity()` must return `Arity::kBinary`, or this will crash. + explicit BinaryOperator(Operator op); + + constexpr BinaryOperatorId id() const { + return static_cast(data_->id); + } + + // Returns the name of the operator. This is the managed representation of the + // operator, for example "_&&_". + constexpr absl::string_view name() const { return data_->name; } + + // Returns the source text representation of the operator. This is the + // unmanaged text representation of the operator, for example "&&". + // + // Note that this will be empty for operators like Conditional() and Index(). + constexpr absl::string_view display_name() const { + return data_->display_name; + } + + constexpr int precedence() const { return data_->precedence; } + + constexpr Arity arity() const { + ABSL_ASSERT(data_->arity == 2); + return Arity::kBinary; + } + + constexpr operator Operator() const { // NOLINT(google-explicit-constructor) + return Operator(data_); + } + + private: + friend class Operator; + + constexpr explicit BinaryOperator(const base_internal::OperatorData* data) + : data_(data) {} + + const base_internal::OperatorData* data_; +}; + +constexpr bool operator==(const BinaryOperator& lhs, + const BinaryOperator& rhs) { + return lhs.id() == rhs.id(); +} + +constexpr bool operator==(BinaryOperatorId lhs, const BinaryOperator& rhs) { + return lhs == rhs.id(); +} + +constexpr bool operator==(const BinaryOperator& lhs, BinaryOperatorId rhs) { + return operator==(rhs, lhs); +} + +constexpr bool operator!=(const BinaryOperator& lhs, + const BinaryOperator& rhs) { + return !operator==(lhs, rhs); +} + +constexpr bool operator!=(BinaryOperatorId lhs, const BinaryOperator& rhs) { + return !operator==(lhs, rhs); +} + +constexpr bool operator!=(const BinaryOperator& lhs, BinaryOperatorId rhs) { + return !operator==(lhs, rhs); +} + +template +H AbslHashValue(H state, const BinaryOperator& op) { + return H::combine(std::move(state), static_cast(op.id())); +} + +class TernaryOperator final { + public: + ABSL_ATTRIBUTE_PURE_FUNCTION static TernaryOperator Conditional() { + return Operator::Conditional(); + } + + ABSL_ATTRIBUTE_PURE_FUNCTION static absl::optional + FindByName(absl::string_view input); + + ABSL_ATTRIBUTE_PURE_FUNCTION static absl::optional + FindByDisplayName(absl::string_view input); + + TernaryOperator() = delete; + TernaryOperator(const TernaryOperator&) = default; + TernaryOperator(TernaryOperator&&) = default; + TernaryOperator& operator=(const TernaryOperator&) = default; + TernaryOperator& operator=(TernaryOperator&&) = default; + + // Support for explicit casting of Operator to TernaryOperator. + // `Operator::arity()` must return `Arity::kTernary`, or this will crash. + explicit TernaryOperator(Operator op); + + constexpr TernaryOperatorId id() const { + return static_cast(data_->id); + } + + // Returns the name of the operator. This is the managed representation of the + // operator, for example "_&&_". + constexpr absl::string_view name() const { return data_->name; } + + // Returns the source text representation of the operator. This is the + // unmanaged text representation of the operator, for example "&&". + // + // Note that this will be empty for operators like Conditional() and Index(). + constexpr absl::string_view display_name() const { + return data_->display_name; + } + + constexpr int precedence() const { return data_->precedence; } + + constexpr Arity arity() const { + ABSL_ASSERT(data_->arity == 3); + return Arity::kTernary; + } + + constexpr operator Operator() const { // NOLINT(google-explicit-constructor) + return Operator(data_); + } + + private: + friend class Operator; + + constexpr explicit TernaryOperator(const base_internal::OperatorData* data) + : data_(data) {} + + const base_internal::OperatorData* data_; +}; + +constexpr bool operator==(const TernaryOperator& lhs, + const TernaryOperator& rhs) { + return lhs.id() == rhs.id(); +} + +constexpr bool operator==(TernaryOperatorId lhs, const TernaryOperator& rhs) { + return lhs == rhs.id(); +} + +constexpr bool operator==(const TernaryOperator& lhs, TernaryOperatorId rhs) { + return operator==(rhs, lhs); +} + +constexpr bool operator!=(const TernaryOperator& lhs, + const TernaryOperator& rhs) { + return !operator==(lhs, rhs); +} + +constexpr bool operator!=(TernaryOperatorId lhs, const TernaryOperator& rhs) { + return !operator==(lhs, rhs); +} + +constexpr bool operator!=(const TernaryOperator& lhs, TernaryOperatorId rhs) { + return !operator==(lhs, rhs); +} + +template +H AbslHashValue(H state, const TernaryOperator& op) { + return H::combine(std::move(state), static_cast(op.id())); } } // namespace cel diff --git a/base/operators_test.cc b/base/operators_test.cc index b86743e7e..9feb3746e 100644 --- a/base/operators_test.cc +++ b/base/operators_test.cc @@ -1,4 +1,4 @@ -// Copyright 2021 Google LLC +// Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -17,250 +17,192 @@ #include #include "absl/hash/hash_testing.h" -#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "base/internal/operators.h" #include "internal/testing.h" namespace cel { namespace { -using cel::internal::StatusIs; +using testing::Eq; +using testing::Optional; -TEST(Operator, TypeTraits) { - EXPECT_FALSE(std::is_default_constructible_v); - EXPECT_TRUE(std::is_copy_constructible_v); - EXPECT_TRUE(std::is_move_constructible_v); - EXPECT_TRUE(std::is_copy_assignable_v); - EXPECT_TRUE(std::is_move_assignable_v); +template +void TestOperator(Op op, OpId id, absl::string_view name, + absl::string_view display_name, int precedence, Arity arity) { + EXPECT_EQ(op.id(), id); + EXPECT_EQ(Operator(op).id(), static_cast(id)); + EXPECT_EQ(op.name(), name); + EXPECT_EQ(op.display_name(), display_name); + EXPECT_EQ(op.precedence(), precedence); + EXPECT_EQ(op.arity(), arity); + EXPECT_EQ(Operator(op).arity(), arity); + EXPECT_EQ(Op(Operator(op)), op); } -TEST(Operator, Conditional) { - EXPECT_EQ(Operator::Conditional().id(), OperatorId::kConditional); - EXPECT_EQ(Operator::Conditional().name(), "_?_:_"); - EXPECT_EQ(Operator::Conditional().display_name(), ""); - EXPECT_EQ(Operator::Conditional().precedence(), 8); - EXPECT_EQ(Operator::Conditional().arity(), 3); +void TestUnaryOperator(UnaryOperator op, UnaryOperatorId id, + absl::string_view name, absl::string_view display_name, + int precedence) { + TestOperator(op, id, name, display_name, precedence, Arity::kUnary); } -TEST(Operator, LogicalAnd) { - EXPECT_EQ(Operator::LogicalAnd().id(), OperatorId::kLogicalAnd); - EXPECT_EQ(Operator::LogicalAnd().name(), "_&&_"); - EXPECT_EQ(Operator::LogicalAnd().display_name(), "&&"); - EXPECT_EQ(Operator::LogicalAnd().precedence(), 6); - EXPECT_EQ(Operator::LogicalAnd().arity(), 2); +void TestBinaryOperator(BinaryOperator op, BinaryOperatorId id, + absl::string_view name, absl::string_view display_name, + int precedence) { + TestOperator(op, id, name, display_name, precedence, Arity::kBinary); } -TEST(Operator, LogicalOr) { - EXPECT_EQ(Operator::LogicalOr().id(), OperatorId::kLogicalOr); - EXPECT_EQ(Operator::LogicalOr().name(), "_||_"); - EXPECT_EQ(Operator::LogicalOr().display_name(), "||"); - EXPECT_EQ(Operator::LogicalOr().precedence(), 7); - EXPECT_EQ(Operator::LogicalOr().arity(), 2); +void TestTernaryOperator(TernaryOperator op, TernaryOperatorId id, + absl::string_view name, absl::string_view display_name, + int precedence) { + TestOperator(op, id, name, display_name, precedence, Arity::kTernary); } -TEST(Operator, LogicalNot) { - EXPECT_EQ(Operator::LogicalNot().id(), OperatorId::kLogicalNot); - EXPECT_EQ(Operator::LogicalNot().name(), "!_"); - EXPECT_EQ(Operator::LogicalNot().display_name(), "!"); - EXPECT_EQ(Operator::LogicalNot().precedence(), 2); - EXPECT_EQ(Operator::LogicalNot().arity(), 1); +TEST(Operator, TypeTraits) { + EXPECT_FALSE(std::is_default_constructible_v); + EXPECT_TRUE(std::is_copy_constructible_v); + EXPECT_TRUE(std::is_move_constructible_v); + EXPECT_TRUE(std::is_copy_assignable_v); + EXPECT_TRUE(std::is_move_assignable_v); + EXPECT_FALSE((std::is_convertible_v)); + EXPECT_FALSE((std::is_convertible_v)); + EXPECT_FALSE((std::is_convertible_v)); } -TEST(Operator, Equals) { - EXPECT_EQ(Operator::Equals().id(), OperatorId::kEquals); - EXPECT_EQ(Operator::Equals().name(), "_==_"); - EXPECT_EQ(Operator::Equals().display_name(), "=="); - EXPECT_EQ(Operator::Equals().precedence(), 5); - EXPECT_EQ(Operator::Equals().arity(), 2); +TEST(UnaryOperator, TypeTraits) { + EXPECT_FALSE(std::is_default_constructible_v); + EXPECT_TRUE(std::is_copy_constructible_v); + EXPECT_TRUE(std::is_move_constructible_v); + EXPECT_TRUE(std::is_copy_assignable_v); + EXPECT_TRUE(std::is_move_assignable_v); + EXPECT_TRUE((std::is_convertible_v)); } -TEST(Operator, NotEquals) { - EXPECT_EQ(Operator::NotEquals().id(), OperatorId::kNotEquals); - EXPECT_EQ(Operator::NotEquals().name(), "_!=_"); - EXPECT_EQ(Operator::NotEquals().display_name(), "!="); - EXPECT_EQ(Operator::NotEquals().precedence(), 5); - EXPECT_EQ(Operator::NotEquals().arity(), 2); +TEST(BinaryOperator, TypeTraits) { + EXPECT_FALSE(std::is_default_constructible_v); + EXPECT_TRUE(std::is_copy_constructible_v); + EXPECT_TRUE(std::is_move_constructible_v); + EXPECT_TRUE(std::is_copy_assignable_v); + EXPECT_TRUE(std::is_move_assignable_v); + EXPECT_TRUE((std::is_convertible_v)); } -TEST(Operator, Less) { - EXPECT_EQ(Operator::Less().id(), OperatorId::kLess); - EXPECT_EQ(Operator::Less().name(), "_<_"); - EXPECT_EQ(Operator::Less().display_name(), "<"); - EXPECT_EQ(Operator::Less().precedence(), 5); - EXPECT_EQ(Operator::Less().arity(), 2); +TEST(TernaryOperator, TypeTraits) { + EXPECT_FALSE(std::is_default_constructible_v); + EXPECT_TRUE(std::is_copy_constructible_v); + EXPECT_TRUE(std::is_move_constructible_v); + EXPECT_TRUE(std::is_copy_assignable_v); + EXPECT_TRUE(std::is_move_assignable_v); + EXPECT_TRUE((std::is_convertible_v)); } -TEST(Operator, LessEquals) { - EXPECT_EQ(Operator::LessEquals().id(), OperatorId::kLessEquals); - EXPECT_EQ(Operator::LessEquals().name(), "_<=_"); - EXPECT_EQ(Operator::LessEquals().display_name(), "<="); - EXPECT_EQ(Operator::LessEquals().precedence(), 5); - EXPECT_EQ(Operator::LessEquals().arity(), 2); -} +#define CEL_UNARY_OPERATOR(id, symbol, name, precedence, arity) \ + TEST(UnaryOperator, id) { \ + TestUnaryOperator(UnaryOperator::id(), UnaryOperatorId::k##id, name, \ + symbol, precedence); \ + } -TEST(Operator, Greater) { - EXPECT_EQ(Operator::Greater().id(), OperatorId::kGreater); - EXPECT_EQ(Operator::Greater().name(), "_>_"); - EXPECT_EQ(Operator::Greater().display_name(), ">"); - EXPECT_EQ(Operator::Greater().precedence(), 5); - EXPECT_EQ(Operator::Greater().arity(), 2); -} +CEL_INTERNAL_UNARY_OPERATORS_ENUM(CEL_UNARY_OPERATOR) -TEST(Operator, GreaterEquals) { - EXPECT_EQ(Operator::GreaterEquals().id(), OperatorId::kGreaterEquals); - EXPECT_EQ(Operator::GreaterEquals().name(), "_>=_"); - EXPECT_EQ(Operator::GreaterEquals().display_name(), ">="); - EXPECT_EQ(Operator::GreaterEquals().precedence(), 5); - EXPECT_EQ(Operator::GreaterEquals().arity(), 2); -} +#undef CEL_UNARY_OPERATOR -TEST(Operator, Add) { - EXPECT_EQ(Operator::Add().id(), OperatorId::kAdd); - EXPECT_EQ(Operator::Add().name(), "_+_"); - EXPECT_EQ(Operator::Add().display_name(), "+"); - EXPECT_EQ(Operator::Add().precedence(), 4); - EXPECT_EQ(Operator::Add().arity(), 2); -} +#define CEL_BINARY_OPERATOR(id, symbol, name, precedence, arity) \ + TEST(BinaryOperator, id) { \ + TestBinaryOperator(BinaryOperator::id(), BinaryOperatorId::k##id, name, \ + symbol, precedence); \ + } -TEST(Operator, Subtract) { - EXPECT_EQ(Operator::Subtract().id(), OperatorId::kSubtract); - EXPECT_EQ(Operator::Subtract().name(), "_-_"); - EXPECT_EQ(Operator::Subtract().display_name(), "-"); - EXPECT_EQ(Operator::Subtract().precedence(), 4); - EXPECT_EQ(Operator::Subtract().arity(), 2); -} +CEL_INTERNAL_BINARY_OPERATORS_ENUM(CEL_BINARY_OPERATOR) -TEST(Operator, Multiply) { - EXPECT_EQ(Operator::Multiply().id(), OperatorId::kMultiply); - EXPECT_EQ(Operator::Multiply().name(), "_*_"); - EXPECT_EQ(Operator::Multiply().display_name(), "*"); - EXPECT_EQ(Operator::Multiply().precedence(), 3); - EXPECT_EQ(Operator::Multiply().arity(), 2); -} +#undef CEL_BINARY_OPERATOR -TEST(Operator, Divide) { - EXPECT_EQ(Operator::Divide().id(), OperatorId::kDivide); - EXPECT_EQ(Operator::Divide().name(), "_/_"); - EXPECT_EQ(Operator::Divide().display_name(), "/"); - EXPECT_EQ(Operator::Divide().precedence(), 3); - EXPECT_EQ(Operator::Divide().arity(), 2); -} +#define CEL_TERNARY_OPERATOR(id, symbol, name, precedence, arity) \ + TEST(TernaryOperator, id) { \ + TestTernaryOperator(TernaryOperator::id(), TernaryOperatorId::k##id, name, \ + symbol, precedence); \ + } -TEST(Operator, Modulo) { - EXPECT_EQ(Operator::Modulo().id(), OperatorId::kModulo); - EXPECT_EQ(Operator::Modulo().name(), "_%_"); - EXPECT_EQ(Operator::Modulo().display_name(), "%"); - EXPECT_EQ(Operator::Modulo().precedence(), 3); - EXPECT_EQ(Operator::Modulo().arity(), 2); -} +CEL_INTERNAL_TERNARY_OPERATORS_ENUM(CEL_TERNARY_OPERATOR) + +#undef CEL_TERNARY_OPERATOR -TEST(Operator, Negate) { - EXPECT_EQ(Operator::Negate().id(), OperatorId::kNegate); - EXPECT_EQ(Operator::Negate().name(), "-_"); - EXPECT_EQ(Operator::Negate().display_name(), "-"); - EXPECT_EQ(Operator::Negate().precedence(), 2); - EXPECT_EQ(Operator::Negate().arity(), 1); +TEST(Operator, FindByName) { + EXPECT_THAT(Operator::FindByName("@in"), Optional(Eq(Operator::In()))); + EXPECT_THAT(Operator::FindByName("_in_"), Optional(Eq(Operator::OldIn()))); + EXPECT_THAT(Operator::FindByName("in"), Eq(absl::nullopt)); + EXPECT_THAT(Operator::FindByName(""), Eq(absl::nullopt)); } -TEST(Operator, Index) { - EXPECT_EQ(Operator::Index().id(), OperatorId::kIndex); - EXPECT_EQ(Operator::Index().name(), "_[_]"); - EXPECT_EQ(Operator::Index().display_name(), ""); - EXPECT_EQ(Operator::Index().precedence(), 1); - EXPECT_EQ(Operator::Index().arity(), 2); +TEST(Operator, FindByDisplayName) { + EXPECT_THAT(Operator::FindByDisplayName("-"), + Optional(Eq(Operator::Subtract()))); + EXPECT_THAT(Operator::FindByDisplayName("@in"), Eq(absl::nullopt)); + EXPECT_THAT(Operator::FindByDisplayName(""), Eq(absl::nullopt)); } -TEST(Operator, In) { - EXPECT_EQ(Operator::In().id(), OperatorId::kIn); - EXPECT_EQ(Operator::In().name(), "@in"); - EXPECT_EQ(Operator::In().display_name(), "in"); - EXPECT_EQ(Operator::In().precedence(), 5); - EXPECT_EQ(Operator::In().arity(), 2); +TEST(UnaryOperator, FindByName) { + EXPECT_THAT(UnaryOperator::FindByName("-_"), + Optional(Eq(Operator::Negate()))); + EXPECT_THAT(UnaryOperator::FindByName("_-_"), Eq(absl::nullopt)); + EXPECT_THAT(UnaryOperator::FindByName(""), Eq(absl::nullopt)); } -TEST(Operator, NotStrictlyFalse) { - EXPECT_EQ(Operator::NotStrictlyFalse().id(), OperatorId::kNotStrictlyFalse); - EXPECT_EQ(Operator::NotStrictlyFalse().name(), "@not_strictly_false"); - EXPECT_EQ(Operator::NotStrictlyFalse().display_name(), ""); - EXPECT_EQ(Operator::NotStrictlyFalse().precedence(), 0); - EXPECT_EQ(Operator::NotStrictlyFalse().arity(), 1); +TEST(UnaryOperator, FindByDisplayName) { + EXPECT_THAT(UnaryOperator::FindByDisplayName("-"), + Optional(Eq(Operator::Negate()))); + EXPECT_THAT(UnaryOperator::FindByDisplayName("&&"), Eq(absl::nullopt)); + EXPECT_THAT(UnaryOperator::FindByDisplayName(""), Eq(absl::nullopt)); } -TEST(Operator, OldIn) { - EXPECT_EQ(Operator::OldIn().id(), OperatorId::kOldIn); - EXPECT_EQ(Operator::OldIn().name(), "_in_"); - EXPECT_EQ(Operator::OldIn().display_name(), "in"); - EXPECT_EQ(Operator::OldIn().precedence(), 5); - EXPECT_EQ(Operator::OldIn().arity(), 2); +TEST(BinaryOperator, FindByName) { + EXPECT_THAT(BinaryOperator::FindByName("_-_"), + Optional(Eq(Operator::Subtract()))); + EXPECT_THAT(BinaryOperator::FindByName("-_"), Eq(absl::nullopt)); + EXPECT_THAT(BinaryOperator::FindByName(""), Eq(absl::nullopt)); } -TEST(Operator, OldNotStrictlyFalse) { - EXPECT_EQ(Operator::OldNotStrictlyFalse().id(), - OperatorId::kOldNotStrictlyFalse); - EXPECT_EQ(Operator::OldNotStrictlyFalse().name(), "__not_strictly_false__"); - EXPECT_EQ(Operator::OldNotStrictlyFalse().display_name(), ""); - EXPECT_EQ(Operator::OldNotStrictlyFalse().precedence(), 0); - EXPECT_EQ(Operator::OldNotStrictlyFalse().arity(), 1); +TEST(BinaryOperator, FindByDisplayName) { + EXPECT_THAT(BinaryOperator::FindByDisplayName("-"), + Optional(Eq(Operator::Subtract()))); + EXPECT_THAT(BinaryOperator::FindByDisplayName("!"), Eq(absl::nullopt)); + EXPECT_THAT(BinaryOperator::FindByDisplayName(""), Eq(absl::nullopt)); } -TEST(Operator, FindByName) { - auto status_or_operator = Operator::FindByName("@in"); - EXPECT_OK(status_or_operator); - EXPECT_EQ(status_or_operator.value(), Operator::In()); - status_or_operator = Operator::FindByName("_in_"); - EXPECT_OK(status_or_operator); - EXPECT_EQ(status_or_operator.value(), Operator::OldIn()); - status_or_operator = Operator::FindByName("in"); - EXPECT_THAT(status_or_operator, StatusIs(absl::StatusCode::kNotFound)); +TEST(TernaryOperator, FindByName) { + EXPECT_THAT(TernaryOperator::FindByName("_?_:_"), + Optional(Eq(TernaryOperator::Conditional()))); + EXPECT_THAT(TernaryOperator::FindByName("-_"), Eq(absl::nullopt)); + EXPECT_THAT(TernaryOperator::FindByName(""), Eq(absl::nullopt)); } -TEST(Operator, FindByDisplayName) { - auto status_or_operator = Operator::FindByDisplayName("-"); - EXPECT_OK(status_or_operator); - EXPECT_EQ(status_or_operator.value(), Operator::Subtract()); - status_or_operator = Operator::FindByDisplayName("@in"); - EXPECT_THAT(status_or_operator, StatusIs(absl::StatusCode::kNotFound)); +TEST(TernaryOperator, FindByDisplayName) { + EXPECT_THAT(TernaryOperator::FindByDisplayName(""), Eq(absl::nullopt)); + EXPECT_THAT(TernaryOperator::FindByDisplayName("!"), Eq(absl::nullopt)); } -TEST(Operator, FindUnaryByDisplayName) { - auto status_or_operator = Operator::FindUnaryByDisplayName("-"); - EXPECT_OK(status_or_operator); - EXPECT_EQ(status_or_operator.value(), Operator::Negate()); - status_or_operator = Operator::FindUnaryByDisplayName("&&"); - EXPECT_THAT(status_or_operator, StatusIs(absl::StatusCode::kNotFound)); +TEST(Operator, SupportsAbslHash) { +#define CEL_OPERATOR(id, symbol, name, precedence, arity) \ + Operator(Operator::id()), + EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly( + {CEL_INTERNAL_OPERATORS_ENUM(CEL_OPERATOR)})); +#undef CEL_OPERATOR } -TEST(Operator, FindBinaryByDisplayName) { - auto status_or_operator = Operator::FindBinaryByDisplayName("-"); - EXPECT_OK(status_or_operator); - EXPECT_EQ(status_or_operator.value(), Operator::Subtract()); - status_or_operator = Operator::FindBinaryByDisplayName("!"); - EXPECT_THAT(status_or_operator, StatusIs(absl::StatusCode::kNotFound)); +TEST(UnaryOperator, SupportsAbslHash) { +#define CEL_UNARY_OPERATOR(id, symbol, name, precedence, arity) \ + UnaryOperator::id(), + EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly( + {CEL_INTERNAL_UNARY_OPERATORS_ENUM(CEL_UNARY_OPERATOR)})); +#undef CEL_UNARY_OPERATOR } -TEST(Type, SupportsAbslHash) { - EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly({ - Operator::Conditional(), - Operator::LogicalAnd(), - Operator::LogicalOr(), - Operator::LogicalNot(), - Operator::Equals(), - Operator::NotEquals(), - Operator::Less(), - Operator::LessEquals(), - Operator::Greater(), - Operator::GreaterEquals(), - Operator::Add(), - Operator::Subtract(), - Operator::Multiply(), - Operator::Divide(), - Operator::Modulo(), - Operator::Negate(), - Operator::Index(), - Operator::In(), - Operator::NotStrictlyFalse(), - Operator::OldIn(), - Operator::OldNotStrictlyFalse(), - })); +TEST(BinaryOperator, SupportsAbslHash) { +#define CEL_BINARY_OPERATOR(id, symbol, name, precedence, arity) \ + BinaryOperator::id(), + EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly( + {CEL_INTERNAL_BINARY_OPERATORS_ENUM(CEL_BINARY_OPERATOR)})); +#undef CEL_BINARY_OPERATOR } } // namespace From bb5cc7088b3ca70ff71909c8a67041cedf7a00e9 Mon Sep 17 00:00:00 2001 From: jcking Date: Fri, 21 Apr 2023 19:09:56 +0000 Subject: [PATCH 221/303] Place `cel::Function` and `cel::FunctionDescriptor` in their appropriate headers Also breaks up `:functions`. PiperOrigin-RevId: 526102496 --- base/BUILD | 60 ++++++++---- base/function.h | 98 +++++++++----------- base/function_adapter.h | 2 +- base/function_adapter_test.cc | 2 +- base/{function.cc => function_descriptor.cc} | 9 +- base/function_descriptor.h | 85 +++++++++++++++++ base/function_interface.h | 70 -------------- base/function_result.h | 2 +- base/internal/BUILD | 2 +- eval/compiler/BUILD | 2 +- eval/compiler/constant_folding.cc | 2 +- eval/eval/BUILD | 10 +- eval/eval/attribute_utility.h | 2 +- eval/eval/function_step.cc | 2 +- eval/public/BUILD | 17 ++-- eval/public/cel_function.cc | 2 +- eval/public/cel_function.h | 2 +- eval/public/cel_function_registry.cc | 2 +- eval/public/cel_function_registry.h | 2 +- eval/public/logical_function_registrar.cc | 2 +- runtime/BUILD | 22 ++--- runtime/activation.cc | 2 +- runtime/activation.h | 2 +- runtime/activation_test.cc | 2 +- runtime/function_overload_reference.h | 2 +- runtime/function_provider.h | 2 +- runtime/function_registry.cc | 2 +- runtime/function_registry.h | 2 +- runtime/function_registry_test.cc | 2 +- 29 files changed, 222 insertions(+), 191 deletions(-) rename base/{function.cc => function_descriptor.cc} (93%) create mode 100644 base/function_descriptor.h delete mode 100644 base/function_interface.h diff --git a/base/BUILD b/base/BUILD index cde6012a9..96cc952c9 100644 --- a/base/BUILD +++ b/base/BUILD @@ -252,7 +252,7 @@ cc_library( deps = [ ":allocator", ":attributes", - ":functions", + ":function_result_set", ":handle", ":kind", ":memory_manager", @@ -337,43 +337,67 @@ cc_test( ) cc_library( - name = "functions", + name = "function", + hdrs = [ + "function.h", + ], + deps = [ + ":handle", + ":value", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "function_descriptor", srcs = [ - "function.cc", - "function_result_set.cc", + "function_descriptor.cc", ], hdrs = [ - "function.h", - "function_result.h", - "function_result_set.h", + "function_descriptor.h", ], deps = [ ":kind", - "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], ) cc_library( - name = "ast", - hdrs = ["ast.h"], + name = "function_result", + hdrs = [ + "function_result.h", + ], + deps = [":function_descriptor"], ) cc_library( - name = "function_interface", - hdrs = ["function_interface.h"], + name = "function_result_set", + srcs = [ + "function_result_set.cc", + ], + hdrs = [ + "function_result_set.h", + ], deps = [ - ":value", - "@com_google_absl//absl/types:span", + ":function_result", + "@com_google_absl//absl/container:btree", ], ) +cc_library( + name = "ast", + hdrs = ["ast.h"], +) + cc_library( name = "function_adapter", hdrs = ["function_adapter.h"], deps = [ - ":function_interface", - ":functions", + ":function", + ":function_descriptor", ":handle", ":value", "//base/internal:function_adapter", @@ -390,9 +414,9 @@ cc_test( name = "function_adapter_test", srcs = ["function_adapter_test.cc"], deps = [ + ":function", ":function_adapter", - ":function_interface", - ":functions", + ":function_descriptor", ":handle", ":kind", ":memory_manager", diff --git a/base/function.h b/base/function.h index 7e9487e34..6b71e18a6 100644 --- a/base/function.h +++ b/base/function.h @@ -1,4 +1,4 @@ -// Copyright 2022 Google LLC +// Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,70 +15,56 @@ #ifndef THIRD_PARTY_CEL_CPP_BASE_FUNCTION_H_ #define THIRD_PARTY_CEL_CPP_BASE_FUNCTION_H_ -#include -#include -#include -#include - +#include "absl/status/statusor.h" #include "absl/types/span.h" -#include "base/kind.h" +#include "base/handle.h" +#include "base/value.h" +#include "base/value_factory.h" namespace cel { -// Describes a function. -class FunctionDescriptor final { +// Interface for extension functions. +// +// The host for the CEL environment may provide implementations to define custom +// extensions functions. +// +// The interpreter expects functions to be deterministic and side-effect free. +class Function { public: - FunctionDescriptor(absl::string_view name, bool receiver_style, - std::vector types, bool is_strict = true) - : impl_(std::make_shared(name, receiver_style, std::move(types), - is_strict)) {} - - // Function name. - const std::string& name() const { return impl_->name; } - - // Whether function is receiver style i.e. true means arg0.name(args[1:]...). - bool receiver_style() const { return impl_->receiver_style; } - - // The argmument types the function accepts. - // - // TODO(issues/5): make this kinds - const std::vector& types() const { return impl_->types; } - - // if true (strict, default), error or unknown arguments are propagated - // instead of calling the function. if false (non-strict), the function may - // receive error or unknown values as arguments. - bool is_strict() const { return impl_->is_strict; } - - // Helper for matching a descriptor. This tests that the shape is the same -- - // |other| accepts the same number and types of arguments and is the same call - // style). - bool ShapeMatches(const FunctionDescriptor& other) const { - return ShapeMatches(other.receiver_style(), other.types()); - } - bool ShapeMatches(bool receiver_style, absl::Span types) const; - - bool operator==(const FunctionDescriptor& other) const; - - bool operator<(const FunctionDescriptor& other) const; - - private: - struct Impl final { - Impl(absl::string_view name, bool receiver_style, std::vector types, - bool is_strict) - : name(name), - types(std::move(types)), - receiver_style(receiver_style), - is_strict(is_strict) {} - - std::string name; - std::vector types; - bool receiver_style; - bool is_strict; + virtual ~Function() = default; + + // InvokeContext provides access to current evaluator state. + class InvokeContext final { + public: + explicit InvokeContext(ValueFactory& value_factory) + : value_factory_(value_factory) {} + + // Return the value_factory defined for the evaluation invoking the + // extension function. + cel::ValueFactory& value_factory() const { return value_factory_; } + + // TODO(issues/5): Add accessors for getting attribute stack and mutable + // value stack. + private: + cel::ValueFactory& value_factory_; }; - std::shared_ptr impl_; + // Attempt to evaluate an extension function based on the runtime arguments + // during the evaluation of a CEL expression. + // + // A non-ok status is interpreted as an unrecoverable error in evaluation ( + // e.g. data corruption). This stops evaluation and is propagated immediately. + // + // A cel::ErrorValue typed result is considered a recoverable error and + // follows CEL's logical short-circuiting behavior. + virtual absl::StatusOr> Invoke( + const InvokeContext& context, + absl::Span> args) const = 0; }; +// Legacy type, aliased to the actual type. +using FunctionEvaluationContext = Function::InvokeContext; + } // namespace cel #endif // THIRD_PARTY_CEL_CPP_BASE_FUNCTION_H_ diff --git a/base/function_adapter.h b/base/function_adapter.h index cd2d9aa68..d80b436e8 100644 --- a/base/function_adapter.h +++ b/base/function_adapter.h @@ -29,7 +29,7 @@ #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "base/function.h" -#include "base/function_interface.h" +#include "base/function_descriptor.h" #include "base/handle.h" #include "base/internal/function_adapter.h" #include "base/value.h" diff --git a/base/function_adapter_test.cc b/base/function_adapter_test.cc index 7c23b9844..f5f75f7bf 100644 --- a/base/function_adapter_test.cc +++ b/base/function_adapter_test.cc @@ -22,7 +22,7 @@ #include "absl/status/statusor.h" #include "absl/time/time.h" #include "base/function.h" -#include "base/function_interface.h" +#include "base/function_descriptor.h" #include "base/handle.h" #include "base/kind.h" #include "base/memory_manager.h" diff --git a/base/function.cc b/base/function_descriptor.cc similarity index 93% rename from base/function.cc rename to base/function_descriptor.cc index ff0be8390..3ceff93f3 100644 --- a/base/function.cc +++ b/base/function_descriptor.cc @@ -1,4 +1,4 @@ -// Copyright 2022 Google LLC +// Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,9 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "base/function.h" +#include "base/function_descriptor.h" #include +#include + +#include "absl/base/macros.h" +#include "absl/types/span.h" +#include "base/kind.h" namespace cel { diff --git a/base/function_descriptor.h b/base/function_descriptor.h new file mode 100644 index 000000000..499ad9a85 --- /dev/null +++ b/base/function_descriptor.h @@ -0,0 +1,85 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_FUNCTION_DESCRIPTOR_H_ +#define THIRD_PARTY_CEL_CPP_BASE_FUNCTION_DESCRIPTOR_H_ + +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "base/kind.h" + +namespace cel { + +// Describes a function. +class FunctionDescriptor final { + public: + FunctionDescriptor(absl::string_view name, bool receiver_style, + std::vector types, bool is_strict = true) + : impl_(std::make_shared(name, receiver_style, std::move(types), + is_strict)) {} + + // Function name. + const std::string& name() const { return impl_->name; } + + // Whether function is receiver style i.e. true means arg0.name(args[1:]...). + bool receiver_style() const { return impl_->receiver_style; } + + // The argmument types the function accepts. + // + // TODO(issues/5): make this kinds + const std::vector& types() const { return impl_->types; } + + // if true (strict, default), error or unknown arguments are propagated + // instead of calling the function. if false (non-strict), the function may + // receive error or unknown values as arguments. + bool is_strict() const { return impl_->is_strict; } + + // Helper for matching a descriptor. This tests that the shape is the same -- + // |other| accepts the same number and types of arguments and is the same call + // style). + bool ShapeMatches(const FunctionDescriptor& other) const { + return ShapeMatches(other.receiver_style(), other.types()); + } + bool ShapeMatches(bool receiver_style, absl::Span types) const; + + bool operator==(const FunctionDescriptor& other) const; + + bool operator<(const FunctionDescriptor& other) const; + + private: + struct Impl final { + Impl(absl::string_view name, bool receiver_style, std::vector types, + bool is_strict) + : name(name), + types(std::move(types)), + receiver_style(receiver_style), + is_strict(is_strict) {} + + std::string name; + std::vector types; + bool receiver_style; + bool is_strict; + }; + + std::shared_ptr impl_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_FUNCTION_DESCRIPTOR_H_ diff --git a/base/function_interface.h b/base/function_interface.h deleted file mode 100644 index 6336011da..000000000 --- a/base/function_interface.h +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright 2022 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef THIRD_PARTY_CEL_CPP_BASE_FUNCTION_INTERFACE_H_ -#define THIRD_PARTY_CEL_CPP_BASE_FUNCTION_INTERFACE_H_ - -#include -#include -#include -#include - -#include "absl/types/span.h" -#include "base/value.h" -#include "base/value_factory.h" - -namespace cel { - -// FunctionEvaluationContext provides access to current evaluator state. -class FunctionEvaluationContext { - public: - explicit FunctionEvaluationContext(ValueFactory& value_factory) - : value_factory_(value_factory) {} - - // Return the value_factory defined for the evaluation invoking the extension - // function. - cel::ValueFactory& value_factory() const { return value_factory_; } - - // TODO(issues/5): Add accessors for getting attribute stack and mutable - // value stack. - private: - cel::ValueFactory& value_factory_; -}; - -// Interface for extension functions. -// -// The host for the CEL environment may provide implementations to define custom -// extensions functions. -// -// The interpreter expects functions to be deterministic and side-effect free. -class Function { - public: - virtual ~Function() = default; - - // Attempt to evaluate an extension function based on the runtime arguments - // during the evaluation of a CEL expression. - // - // A non-ok status is interpreted as an unrecoverable error in evaluation ( - // e.g. data corruption). This stops evaluation and is propagated immediately. - // - // A cel::ErrorValue typed result is considered a recoverable error and - // follows CEL's logical short-circuiting behavior. - virtual absl::StatusOr> Invoke( - const FunctionEvaluationContext& context, - absl::Span> args) const = 0; -}; - -} // namespace cel - -#endif // THIRD_PARTY_CEL_CPP_BASE_FUNCTION_INTERFACE_H_ diff --git a/base/function_result.h b/base/function_result.h index 9bc2d6713..da6e164ea 100644 --- a/base/function_result.h +++ b/base/function_result.h @@ -18,7 +18,7 @@ #include #include -#include "base/function.h" +#include "base/function_descriptor.h" namespace cel { diff --git a/base/internal/BUILD b/base/internal/BUILD index c970b1127..cc05c632f 100644 --- a/base/internal/BUILD +++ b/base/internal/BUILD @@ -108,7 +108,7 @@ cc_library( hdrs = ["unknown_set.h"], deps = [ "//base:attributes", - "//base:functions", + "//base:function_result_set", "//internal:no_destructor", "@com_google_absl//absl/base:core_headers", ], diff --git a/eval/compiler/BUILD b/eval/compiler/BUILD index af9be3086..329e285b3 100644 --- a/eval/compiler/BUILD +++ b/eval/compiler/BUILD @@ -144,7 +144,7 @@ cc_library( ], deps = [ "//base:ast_internal", - "//base:function_interface", + "//base:function", "//base:value", "//eval/internal:errors", "//eval/internal:interop", diff --git a/eval/compiler/constant_folding.cc b/eval/compiler/constant_folding.cc index e049c87da..5452aacc2 100644 --- a/eval/compiler/constant_folding.cc +++ b/eval/compiler/constant_folding.cc @@ -7,7 +7,7 @@ #include "absl/strings/str_cat.h" #include "base/ast_internal.h" -#include "base/function_interface.h" +#include "base/function.h" #include "base/values/error_value.h" #include "eval/internal/errors.h" #include "eval/internal/interop.h" diff --git a/eval/eval/BUILD b/eval/eval/BUILD index 4157d657f..2fa4699b2 100644 --- a/eval/eval/BUILD +++ b/eval/eval/BUILD @@ -189,18 +189,16 @@ cc_library( ":attribute_trail", ":evaluator_core", ":expression_step_base", - "//base:function_interface", - "//base:functions", + "//base:function", + "//base:function_descriptor", "//base:handle", "//base:kind", "//base:value", "//eval/internal:errors", "//eval/internal:interop", - "//eval/public:base_activation", "//eval/public:cel_function", "//eval/public:cel_function_registry", "//eval/public:cel_value", - "//eval/public:unknown_attribute_set", "//eval/public:unknown_set", "//internal:status_macros", "//runtime:activation_interface", @@ -703,7 +701,9 @@ cc_library( deps = [ ":attribute_trail", "//base:attributes", - "//base:functions", + "//base:function_descriptor", + "//base:function_result", + "//base:function_result_set", "//base:handle", "//base:memory_manager", "//base:value", diff --git a/eval/eval/attribute_utility.h b/eval/eval/attribute_utility.h index 74f80b519..a7e003ac6 100644 --- a/eval/eval/attribute_utility.h +++ b/eval/eval/attribute_utility.h @@ -7,7 +7,7 @@ #include "google/protobuf/arena.h" #include "absl/types/optional.h" #include "absl/types/span.h" -#include "base/function.h" +#include "base/function_descriptor.h" #include "base/function_result.h" #include "base/function_result_set.h" #include "base/handle.h" diff --git a/eval/eval/function_step.cc b/eval/eval/function_step.cc index 6766a5624..230bddf2a 100644 --- a/eval/eval/function_step.cc +++ b/eval/eval/function_step.cc @@ -16,7 +16,7 @@ #include "absl/types/optional.h" #include "absl/types/span.h" #include "base/function.h" -#include "base/function_interface.h" +#include "base/function_descriptor.h" #include "base/handle.h" #include "base/kind.h" #include "base/value.h" diff --git a/eval/public/BUILD b/eval/public/BUILD index 305688463..b91e635f1 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -182,8 +182,8 @@ cc_library( ], deps = [ ":cel_value", - "//base:function_interface", - "//base:functions", + "//base:function", + "//base:function_descriptor", "//base:handle", "//base:value", "//eval/internal:interop", @@ -449,10 +449,8 @@ cc_library( ":cel_function_registry", ":cel_options", "//base:function_adapter", - "//base:functions", - "//base:kind", + "//base:function_descriptor", "//base:value", - "//base:value_factory", "//eval/internal:errors", "//internal:status_macros", "@com_google_absl//absl/status", @@ -681,8 +679,8 @@ cc_library( ":cel_function", ":cel_options", ":cel_value", - "//base:function_interface", - "//base:functions", + "//base:function", + "//base:function_descriptor", "//base:kind", "//base:type", "//base:value", @@ -1053,7 +1051,10 @@ cc_library( name = "unknown_function_result_set", srcs = ["unknown_function_result_set.cc"], hdrs = ["unknown_function_result_set.h"], - deps = ["//base:functions"], + deps = [ + "//base:function_result", + "//base:function_result_set", + ], ) cc_test( diff --git a/eval/public/cel_function.cc b/eval/public/cel_function.cc index 305926a8c..274b37c29 100644 --- a/eval/public/cel_function.cc +++ b/eval/public/cel_function.cc @@ -6,7 +6,7 @@ #include #include -#include "base/function_interface.h" +#include "base/function.h" #include "eval/internal/interop.h" #include "extensions/protobuf/memory_manager.h" #include "internal/status_macros.h" diff --git a/eval/public/cel_function.h b/eval/public/cel_function.h index b51296b09..2cc9ea0fe 100644 --- a/eval/public/cel_function.h +++ b/eval/public/cel_function.h @@ -10,7 +10,7 @@ #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "base/function.h" -#include "base/function_interface.h" +#include "base/function_descriptor.h" #include "base/handle.h" #include "base/value.h" #include "eval/public/cel_value.h" diff --git a/eval/public/cel_function_registry.cc b/eval/public/cel_function_registry.cc index 538b7dca4..01fb234f3 100644 --- a/eval/public/cel_function_registry.cc +++ b/eval/public/cel_function_registry.cc @@ -12,7 +12,7 @@ #include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "base/function.h" -#include "base/function_interface.h" +#include "base/function_descriptor.h" #include "base/type_manager.h" #include "base/type_provider.h" #include "base/value.h" diff --git a/eval/public/cel_function_registry.h b/eval/public/cel_function_registry.h index d8bbe632a..ced74f617 100644 --- a/eval/public/cel_function_registry.h +++ b/eval/public/cel_function_registry.h @@ -14,7 +14,7 @@ #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "base/function.h" -#include "base/function_interface.h" +#include "base/function_descriptor.h" #include "base/kind.h" #include "eval/public/cel_function.h" #include "eval/public/cel_options.h" diff --git a/eval/public/logical_function_registrar.cc b/eval/public/logical_function_registrar.cc index a827099f1..ce03e3a2f 100644 --- a/eval/public/logical_function_registrar.cc +++ b/eval/public/logical_function_registrar.cc @@ -24,8 +24,8 @@ #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" -#include "base/function.h" #include "base/function_adapter.h" +#include "base/function_descriptor.h" #include "base/value_factory.h" #include "base/values/bool_value.h" #include "base/values/error_value.h" diff --git a/runtime/BUILD b/runtime/BUILD index fdb743d33..4762f82f0 100644 --- a/runtime/BUILD +++ b/runtime/BUILD @@ -37,8 +37,8 @@ cc_library( name = "function_overload_reference", hdrs = ["function_overload_reference.h"], deps = [ - "//base:function_interface", - "//base:functions", + "//base:function", + "//base:function_descriptor", ], ) @@ -48,7 +48,7 @@ cc_library( deps = [ ":activation_interface", ":function_overload_reference", - "//base:functions", + "//base:function_descriptor", "@com_google_absl//absl/status:statusor", ], ) @@ -61,8 +61,8 @@ cc_library( ":activation_interface", ":function_overload_reference", "//base:attributes", - "//base:function_interface", - "//base:functions", + "//base:function", + "//base:function_descriptor", "//base:handle", "//base:value", "@com_google_absl//absl/container:flat_hash_map", @@ -81,8 +81,8 @@ cc_test( deps = [ ":activation", "//base:attributes", - "//base:function_interface", - "//base:functions", + "//base:function", + "//base:function_descriptor", "//base:handle", "//base:memory_manager", "//base:type", @@ -104,8 +104,8 @@ cc_library( ":activation_interface", ":function_overload_reference", ":function_provider", - "//base:function_interface", - "//base:functions", + "//base:function", + "//base:function_descriptor", "//base:kind", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:node_hash_map", @@ -125,9 +125,9 @@ cc_test( ":function_overload_reference", ":function_provider", ":function_registry", + "//base:function", "//base:function_adapter", - "//base:function_interface", - "//base:functions", + "//base:function_descriptor", "//base:kind", "//base:value", "//internal:testing", diff --git a/runtime/activation.cc b/runtime/activation.cc index 9f9a835c6..e1de45ef9 100644 --- a/runtime/activation.cc +++ b/runtime/activation.cc @@ -23,7 +23,7 @@ #include "absl/synchronization/mutex.h" #include "absl/types/optional.h" #include "base/function.h" -#include "base/function_interface.h" +#include "base/function_descriptor.h" #include "base/handle.h" #include "base/value.h" #include "runtime/function_overload_reference.h" diff --git a/runtime/activation.h b/runtime/activation.h index 18202d8d5..01272df25 100644 --- a/runtime/activation.h +++ b/runtime/activation.h @@ -30,7 +30,7 @@ #include "absl/types/span.h" #include "base/attribute.h" #include "base/function.h" -#include "base/function_interface.h" +#include "base/function_descriptor.h" #include "base/handle.h" #include "base/value.h" #include "base/value_factory.h" diff --git a/runtime/activation_test.cc b/runtime/activation_test.cc index 5ebf02d81..6081bc3de 100644 --- a/runtime/activation_test.cc +++ b/runtime/activation_test.cc @@ -21,7 +21,7 @@ #include "absl/types/span.h" #include "base/attribute.h" #include "base/function.h" -#include "base/function_interface.h" +#include "base/function_descriptor.h" #include "base/handle.h" #include "base/memory_manager.h" #include "base/type_factory.h" diff --git a/runtime/function_overload_reference.h b/runtime/function_overload_reference.h index 56f64c919..c317e8dc2 100644 --- a/runtime/function_overload_reference.h +++ b/runtime/function_overload_reference.h @@ -16,7 +16,7 @@ #define THIRD_PARTY_CEL_CPP_RUNTIME_FUNCTION_OVERLOAD_REFERENCE_H_ #include "base/function.h" -#include "base/function_interface.h" +#include "base/function_descriptor.h" namespace cel { diff --git a/runtime/function_provider.h b/runtime/function_provider.h index b33e16cf4..cca8c62aa 100644 --- a/runtime/function_provider.h +++ b/runtime/function_provider.h @@ -16,7 +16,7 @@ #define THIRD_PARTY_CEL_CPP_RUNTIME_FUNCTION_PROVIDER_H_ #include "absl/status/statusor.h" -#include "base/function.h" +#include "base/function_descriptor.h" #include "runtime/activation_interface.h" #include "runtime/function_overload_reference.h" diff --git a/runtime/function_registry.cc b/runtime/function_registry.cc index ff06eee14..4c16cf40e 100644 --- a/runtime/function_registry.cc +++ b/runtime/function_registry.cc @@ -27,7 +27,7 @@ #include "absl/types/optional.h" #include "absl/types/span.h" #include "base/function.h" -#include "base/function_interface.h" +#include "base/function_descriptor.h" #include "base/kind.h" #include "runtime/activation_interface.h" #include "runtime/function_overload_reference.h" diff --git a/runtime/function_registry.h b/runtime/function_registry.h index 7d84d8d65..c5d765974 100644 --- a/runtime/function_registry.h +++ b/runtime/function_registry.h @@ -26,7 +26,7 @@ #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "base/function.h" -#include "base/function_interface.h" +#include "base/function_descriptor.h" #include "base/kind.h" #include "runtime/function_overload_reference.h" #include "runtime/function_provider.h" diff --git a/runtime/function_registry_test.cc b/runtime/function_registry_test.cc index ae105daa7..5618f5551 100644 --- a/runtime/function_registry_test.cc +++ b/runtime/function_registry_test.cc @@ -22,7 +22,7 @@ #include "absl/status/status.h" #include "base/function.h" #include "base/function_adapter.h" -#include "base/function_interface.h" +#include "base/function_descriptor.h" #include "base/kind.h" #include "base/value_factory.h" #include "internal/testing.h" From 85765f9c5d3a8336eb8305e306e46a4b7c4099da Mon Sep 17 00:00:00 2001 From: jdtatum Date: Mon, 24 Apr 2023 18:57:16 +0000 Subject: [PATCH 222/303] Migrate most remaining functions in builin_func_registrar to new function interface. Container 'in' will be addressed along with equality and comparison functions. PiperOrigin-RevId: 526713939 --- eval/public/BUILD | 2 + eval/public/builtin_func_registrar.cc | 528 +++++++++++++------------- 2 files changed, 262 insertions(+), 268 deletions(-) diff --git a/eval/public/BUILD b/eval/public/BUILD index b91e635f1..38d6d2935 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -291,10 +291,12 @@ cc_library( ":portable_cel_function_adapter", "//base:function_adapter", "//base:handle", + "//base:type", "//base:value", "//eval/eval:mutable_list_impl", "//eval/internal:interop", "//eval/public/containers:container_backed_list_impl", + "//extensions/protobuf:memory_manager", "//internal:casts", "//internal:overflow", "//internal:proto_time_encoding", diff --git a/eval/public/builtin_func_registrar.cc b/eval/public/builtin_func_registrar.cc index b10353481..a69af1d89 100644 --- a/eval/public/builtin_func_registrar.cc +++ b/eval/public/builtin_func_registrar.cc @@ -32,9 +32,12 @@ #include "absl/types/optional.h" #include "base/function_adapter.h" #include "base/handle.h" +#include "base/type_factory.h" #include "base/value.h" #include "base/value_factory.h" #include "base/values/bytes_value.h" +#include "base/values/list_value.h" +#include "base/values/map_value.h" #include "base/values/string_value.h" #include "eval/eval/mutable_list_impl.h" #include "eval/internal/interop.h" @@ -48,6 +51,7 @@ #include "eval/public/equality_function_registrar.h" #include "eval/public/logical_function_registrar.h" #include "eval/public/portable_cel_function_adapter.h" +#include "extensions/protobuf/memory_manager.h" #include "internal/casts.h" #include "internal/overflow.h" #include "internal/proto_time_encoding.h" @@ -63,6 +67,8 @@ namespace { using ::cel::BinaryFunctionAdapter; using ::cel::BytesValue; using ::cel::Handle; +using ::cel::ListValue; +using ::cel::MapValue; using ::cel::StringValue; using ::cel::UnaryFunctionAdapter; using ::cel::Value; @@ -371,40 +377,50 @@ const CelList* AppendList(Arena* arena, const CelList* value1, } // Concatenation for string type. -absl::StatusOr> ConcatString( - ValueFactory& factory, const Handle& value1, - const Handle& value2) { +absl::StatusOr> ConcatString(ValueFactory& factory, + const StringValue& value1, + const StringValue& value2) { return factory.CreateUncheckedStringValue( - absl::StrCat(value1->ToString(), value2->ToString())); + absl::StrCat(value1.ToString(), value2.ToString())); } // Concatenation for bytes type. -absl::StatusOr> ConcatBytes( - ValueFactory& factory, const Handle& value1, - const Handle& value2) { +absl::StatusOr> ConcatBytes(ValueFactory& factory, + const BytesValue& value1, + const BytesValue& value2) { return factory.CreateBytesValue( - absl::StrCat(value1->ToString(), value2->ToString())); + absl::StrCat(value1.ToString(), value2.ToString())); } // Concatenation for CelList type. -const CelList* ConcatList(Arena* arena, const CelList* value1, - const CelList* value2) { +absl::StatusOr> ConcatList(ValueFactory& factory, + const ListValue& value1, + const ListValue& value2) { std::vector joined_values; - int size1 = value1->size(); - int size2 = value2->size(); + int size1 = value1.size(); + int size2 = value2.size(); joined_values.reserve(size1 + size2); + Arena* arena = cel::extensions::ProtoMemoryManager::CastToProtoArena( + factory.memory_manager()); + + ListValue::GetContext context(factory); for (int i = 0; i < size1; i++) { - joined_values.push_back((*value1).Get(arena, i)); + CEL_ASSIGN_OR_RETURN(Handle elem, value1.Get(context, i)); + joined_values.push_back( + cel::interop_internal::ModernValueToLegacyValueOrDie(arena, elem)); } for (int i = 0; i < size2; i++) { - joined_values.push_back((*value2).Get(arena, i)); + CEL_ASSIGN_OR_RETURN(Handle elem, value2.Get(context, i)); + joined_values.push_back( + cel::interop_internal::ModernValueToLegacyValueOrDie(arena, elem)); } auto concatenated = Arena::Create(arena, joined_values); - return concatenated; + + return cel::interop_internal::CreateLegacyListValue(concatenated); } // Timestamp @@ -545,9 +561,9 @@ Handle GetMilliseconds(ValueFactory& value_factory, absl::Time timestamp, } Handle CreateDurationFromString(ValueFactory& value_factory, - const Handle& dur_str) { + const StringValue& dur_str) { absl::Duration d; - if (!absl::ParseDuration(dur_str->ToString(), &d)) { + if (!absl::ParseDuration(dur_str.ToString(), &d)) { return value_factory.CreateErrorValue( absl::InvalidArgumentError("String to Duration conversion failed")); } @@ -561,19 +577,19 @@ Handle CreateDurationFromString(ValueFactory& value_factory, return *duration; } -bool StringContains(ValueFactory&, const Handle& value, - const Handle& substr) { - return absl::StrContains(value->ToString(), substr->ToString()); +bool StringContains(ValueFactory&, const StringValue& value, + const StringValue& substr) { + return absl::StrContains(value.ToString(), substr.ToString()); } -bool StringEndsWith(ValueFactory&, const Handle& value, - const Handle& suffix) { - return absl::EndsWith(value->ToString(), suffix->ToString()); +bool StringEndsWith(ValueFactory&, const StringValue& value, + const StringValue& suffix) { + return absl::EndsWith(value.ToString(), suffix.ToString()); } -bool StringStartsWith(ValueFactory&, const Handle& value, - const Handle& prefix) { - return absl::StartsWith(value->ToString(), prefix->ToString()); +bool StringStartsWith(ValueFactory&, const StringValue& value, + const StringValue& prefix) { + return absl::StartsWith(value.ToString(), prefix.ToString()); } absl::Status RegisterSetMembershipFunctions(CelFunctionRegistry* registry, @@ -744,42 +760,17 @@ absl::Status RegisterSetMembershipFunctions(CelFunctionRegistry* registry, return absl::OkStatus(); } -absl::Status RegisterStringFunctions(CelFunctionRegistry* registry, - const InterpreterOptions& options) { - // Basic substring tests (contains, startsWith, endsWith) - for (bool receiver_style : {true, false}) { - CEL_RETURN_IF_ERROR(registry->Register( - BinaryFunctionAdapter&, - const Handle&>:: - CreateDescriptor(builtin::kStringContains, receiver_style), - BinaryFunctionAdapter< - bool, const Handle&, - const Handle&>::WrapFunction(StringContains))); - - CEL_RETURN_IF_ERROR(registry->Register( - BinaryFunctionAdapter&, - const Handle&>:: - CreateDescriptor(builtin::kStringEndsWith, receiver_style), - BinaryFunctionAdapter< - bool, const Handle&, - const Handle&>::WrapFunction(StringEndsWith))); - - CEL_RETURN_IF_ERROR(registry->Register( - BinaryFunctionAdapter&, - const Handle&>:: - CreateDescriptor(builtin::kStringStartsWith, receiver_style), - BinaryFunctionAdapter< - bool, const Handle&, - const Handle&>::WrapFunction(StringStartsWith))); - } - - // matches function if enabled. +// TODO(issues/5): after refactors for the new value type are done, move this +// to a separate build target to enable subset environments to not depend on +// RE2. +absl::Status RegisterRegexFunctions(CelFunctionRegistry* registry, + const InterpreterOptions& options) { if (options.enable_regex) { - auto regex_matches = - [max_size = options.regex_max_program_size]( - ValueFactory& value_factory, const Handle& target, - const Handle& regex) -> Handle { - RE2 re2(regex->ToString()); + auto regex_matches = [max_size = options.regex_max_program_size]( + ValueFactory& value_factory, + const StringValue& target, + const StringValue& regex) -> Handle { + RE2 re2(regex.ToString()); if (max_size > 0 && re2.ProgramSize() > max_size) { return value_factory.CreateErrorValue( absl::InvalidArgumentError("exceeded RE2 max program size")); @@ -789,14 +780,14 @@ absl::Status RegisterStringFunctions(CelFunctionRegistry* registry, absl::InvalidArgumentError("invalid regex for match")); } return value_factory.CreateBoolValue( - RE2::PartialMatch(target->ToString(), re2)); + RE2::PartialMatch(target.ToString(), re2)); }; // bind str.matches(re) and matches(str, re) for (bool receiver_style : {true, false}) { using MatchFnAdapter = - BinaryFunctionAdapter, const Handle&, - const Handle&>; + BinaryFunctionAdapter, const StringValue&, + const StringValue&>; CEL_RETURN_IF_ERROR( registry->Register(MatchFnAdapter::CreateDescriptor( builtin::kRegexMatch, receiver_style), @@ -804,20 +795,44 @@ absl::Status RegisterStringFunctions(CelFunctionRegistry* registry, } } // if options.enable_regex + return absl::OkStatus(); +} + +absl::Status RegisterStringFunctions(CelFunctionRegistry* registry, + const InterpreterOptions& options) { + // Basic substring tests (contains, startsWith, endsWith) + for (bool receiver_style : {true, false}) { + CEL_RETURN_IF_ERROR(registry->Register( + BinaryFunctionAdapter:: + CreateDescriptor(builtin::kStringContains, receiver_style), + BinaryFunctionAdapter:: + WrapFunction(StringContains))); + + CEL_RETURN_IF_ERROR(registry->Register( + BinaryFunctionAdapter:: + CreateDescriptor(builtin::kStringEndsWith, receiver_style), + BinaryFunctionAdapter:: + WrapFunction(StringEndsWith))); + + CEL_RETURN_IF_ERROR(registry->Register( + BinaryFunctionAdapter:: + CreateDescriptor(builtin::kStringStartsWith, receiver_style), + BinaryFunctionAdapter:: + WrapFunction(StringStartsWith))); + } + // string concatenation if enabled if (options.enable_string_concat) { using StrCatFnAdapter = BinaryFunctionAdapter>, - const Handle&, - const Handle&>; + const StringValue&, const StringValue&>; CEL_RETURN_IF_ERROR(registry->Register( StrCatFnAdapter::CreateDescriptor(builtin::kAdd, false), StrCatFnAdapter::WrapFunction(&ConcatString))); using BytesCatFnAdapter = BinaryFunctionAdapter>, - const Handle&, - const Handle&>; + const BytesValue&, const BytesValue&>; CEL_RETURN_IF_ERROR(registry->Register( BytesCatFnAdapter::CreateDescriptor(builtin::kAdd, false), BytesCatFnAdapter::WrapFunction(&ConcatBytes))); @@ -825,8 +840,8 @@ absl::Status RegisterStringFunctions(CelFunctionRegistry* registry, // String size auto size_func = [](ValueFactory& value_factory, - const Handle& value) -> Handle { - auto [count, valid] = ::cel::internal::Utf8Validate(value->ToString()); + const StringValue& value) -> Handle { + auto [count, valid] = ::cel::internal::Utf8Validate(value.ToString()); if (!valid) { return value_factory.CreateErrorValue( absl::InvalidArgumentError("invalid utf-8 string")); @@ -837,7 +852,7 @@ absl::Status RegisterStringFunctions(CelFunctionRegistry* registry, // receiver style = true/false // Support global and receiver style size() operations on strings. using StrSizeFnAdapter = - UnaryFunctionAdapter, const Handle&>; + UnaryFunctionAdapter, const StringValue&>; CEL_RETURN_IF_ERROR( registry->Register(StrSizeFnAdapter::CreateDescriptor( builtin::kSize, /*receiver_style=*/true), @@ -848,14 +863,12 @@ absl::Status RegisterStringFunctions(CelFunctionRegistry* registry, StrSizeFnAdapter::WrapFunction(size_func))); // Bytes size - auto bytes_size_func = [](ValueFactory&, - const Handle& value) -> int64_t { - return value->size(); + auto bytes_size_func = [](ValueFactory&, const BytesValue& value) -> int64_t { + return value.size(); }; // receiver style = true/false // Support global and receiver style size() operations on bytes. - using BytesSizeFnAdapter = - UnaryFunctionAdapter&>; + using BytesSizeFnAdapter = UnaryFunctionAdapter; CEL_RETURN_IF_ERROR( registry->Register(BytesSizeFnAdapter::CreateDescriptor( builtin::kSize, /*receiver_style=*/true), @@ -871,15 +884,12 @@ absl::Status RegisterStringFunctions(CelFunctionRegistry* registry, absl::Status RegisterTimestampFunctions(CelFunctionRegistry* registry, const InterpreterOptions& options) { CEL_RETURN_IF_ERROR(registry->Register( - BinaryFunctionAdapter< - Handle, absl::Time, - const Handle&>::CreateDescriptor(builtin::kFullYear, - true), - BinaryFunctionAdapter, absl::Time, - const Handle&>:: + BinaryFunctionAdapter, absl::Time, const StringValue&>:: + CreateDescriptor(builtin::kFullYear, true), + BinaryFunctionAdapter, absl::Time, const StringValue&>:: WrapFunction([](ValueFactory& value_factory, absl::Time ts, - const Handle& tz) -> Handle { - return GetFullYear(value_factory, ts, tz->ToString()); + const StringValue& tz) -> Handle { + return GetFullYear(value_factory, ts, tz.ToString()); }))); CEL_RETURN_IF_ERROR(registry->Register( @@ -891,14 +901,12 @@ absl::Status RegisterTimestampFunctions(CelFunctionRegistry* registry, }))); CEL_RETURN_IF_ERROR(registry->Register( - BinaryFunctionAdapter< - Handle, absl::Time, - const Handle&>::CreateDescriptor(builtin::kMonth, true), - BinaryFunctionAdapter, absl::Time, - const Handle&>:: + BinaryFunctionAdapter, absl::Time, const StringValue&>:: + CreateDescriptor(builtin::kMonth, true), + BinaryFunctionAdapter, absl::Time, const StringValue&>:: WrapFunction([](ValueFactory& value_factory, absl::Time ts, - const Handle& tz) -> Handle { - return GetMonth(value_factory, ts, tz->ToString()); + const StringValue& tz) -> Handle { + return GetMonth(value_factory, ts, tz.ToString()); }))); CEL_RETURN_IF_ERROR(registry->Register( @@ -910,15 +918,12 @@ absl::Status RegisterTimestampFunctions(CelFunctionRegistry* registry, }))); CEL_RETURN_IF_ERROR(registry->Register( - BinaryFunctionAdapter< - Handle, absl::Time, - const Handle&>::CreateDescriptor(builtin::kDayOfYear, - true), - BinaryFunctionAdapter, absl::Time, - const Handle&>:: + BinaryFunctionAdapter, absl::Time, const StringValue&>:: + CreateDescriptor(builtin::kDayOfYear, true), + BinaryFunctionAdapter, absl::Time, const StringValue&>:: WrapFunction([](ValueFactory& value_factory, absl::Time ts, - const Handle& tz) -> Handle { - return GetDayOfYear(value_factory, ts, tz->ToString()); + const StringValue& tz) -> Handle { + return GetDayOfYear(value_factory, ts, tz.ToString()); }))); CEL_RETURN_IF_ERROR(registry->Register( @@ -930,15 +935,12 @@ absl::Status RegisterTimestampFunctions(CelFunctionRegistry* registry, }))); CEL_RETURN_IF_ERROR(registry->Register( - BinaryFunctionAdapter< - Handle, absl::Time, - const Handle&>::CreateDescriptor(builtin::kDayOfMonth, - true), - BinaryFunctionAdapter, absl::Time, - const Handle&>:: + BinaryFunctionAdapter, absl::Time, const StringValue&>:: + CreateDescriptor(builtin::kDayOfMonth, true), + BinaryFunctionAdapter, absl::Time, const StringValue&>:: WrapFunction([](ValueFactory& value_factory, absl::Time ts, - const Handle& tz) -> Handle { - return GetDayOfMonth(value_factory, ts, tz->ToString()); + const StringValue& tz) -> Handle { + return GetDayOfMonth(value_factory, ts, tz.ToString()); }))); CEL_RETURN_IF_ERROR(registry->Register( @@ -950,14 +952,12 @@ absl::Status RegisterTimestampFunctions(CelFunctionRegistry* registry, }))); CEL_RETURN_IF_ERROR(registry->Register( - BinaryFunctionAdapter< - Handle, absl::Time, - const Handle&>::CreateDescriptor(builtin::kDate, true), - BinaryFunctionAdapter, absl::Time, - const Handle&>:: + BinaryFunctionAdapter, absl::Time, const StringValue&>:: + CreateDescriptor(builtin::kDate, true), + BinaryFunctionAdapter, absl::Time, const StringValue&>:: WrapFunction([](ValueFactory& value_factory, absl::Time ts, - const Handle& tz) -> Handle { - return GetDate(value_factory, ts, tz->ToString()); + const StringValue& tz) -> Handle { + return GetDate(value_factory, ts, tz.ToString()); }))); CEL_RETURN_IF_ERROR(registry->Register( @@ -969,15 +969,12 @@ absl::Status RegisterTimestampFunctions(CelFunctionRegistry* registry, }))); CEL_RETURN_IF_ERROR(registry->Register( - BinaryFunctionAdapter< - Handle, absl::Time, - const Handle&>::CreateDescriptor(builtin::kDayOfWeek, - true), - BinaryFunctionAdapter, absl::Time, - const Handle&>:: + BinaryFunctionAdapter, absl::Time, const StringValue&>:: + CreateDescriptor(builtin::kDayOfWeek, true), + BinaryFunctionAdapter, absl::Time, const StringValue&>:: WrapFunction([](ValueFactory& value_factory, absl::Time ts, - const Handle& tz) -> Handle { - return GetDayOfWeek(value_factory, ts, tz->ToString()); + const StringValue& tz) -> Handle { + return GetDayOfWeek(value_factory, ts, tz.ToString()); }))); CEL_RETURN_IF_ERROR(registry->Register( @@ -989,14 +986,12 @@ absl::Status RegisterTimestampFunctions(CelFunctionRegistry* registry, }))); CEL_RETURN_IF_ERROR(registry->Register( - BinaryFunctionAdapter< - Handle, absl::Time, - const Handle&>::CreateDescriptor(builtin::kHours, true), - BinaryFunctionAdapter, absl::Time, - const Handle&>:: + BinaryFunctionAdapter, absl::Time, const StringValue&>:: + CreateDescriptor(builtin::kHours, true), + BinaryFunctionAdapter, absl::Time, const StringValue&>:: WrapFunction([](ValueFactory& value_factory, absl::Time ts, - const Handle& tz) -> Handle { - return GetHours(value_factory, ts, tz->ToString()); + const StringValue& tz) -> Handle { + return GetHours(value_factory, ts, tz.ToString()); }))); CEL_RETURN_IF_ERROR(registry->Register( @@ -1008,15 +1003,12 @@ absl::Status RegisterTimestampFunctions(CelFunctionRegistry* registry, }))); CEL_RETURN_IF_ERROR(registry->Register( - BinaryFunctionAdapter< - Handle, absl::Time, - const Handle&>::CreateDescriptor(builtin::kMinutes, - true), - BinaryFunctionAdapter, absl::Time, - const Handle&>:: + BinaryFunctionAdapter, absl::Time, const StringValue&>:: + CreateDescriptor(builtin::kMinutes, true), + BinaryFunctionAdapter, absl::Time, const StringValue&>:: WrapFunction([](ValueFactory& value_factory, absl::Time ts, - const Handle& tz) -> Handle { - return GetMinutes(value_factory, ts, tz->ToString()); + const StringValue& tz) -> Handle { + return GetMinutes(value_factory, ts, tz.ToString()); }))); CEL_RETURN_IF_ERROR(registry->Register( @@ -1028,15 +1020,12 @@ absl::Status RegisterTimestampFunctions(CelFunctionRegistry* registry, }))); CEL_RETURN_IF_ERROR(registry->Register( - BinaryFunctionAdapter< - Handle, absl::Time, - const Handle&>::CreateDescriptor(builtin::kSeconds, - true), - BinaryFunctionAdapter, absl::Time, - const Handle&>:: + BinaryFunctionAdapter, absl::Time, const StringValue&>:: + CreateDescriptor(builtin::kSeconds, true), + BinaryFunctionAdapter, absl::Time, const StringValue&>:: WrapFunction([](ValueFactory& value_factory, absl::Time ts, - const Handle& tz) -> Handle { - return GetSeconds(value_factory, ts, tz->ToString()); + const StringValue& tz) -> Handle { + return GetSeconds(value_factory, ts, tz.ToString()); }))); CEL_RETURN_IF_ERROR(registry->Register( @@ -1048,15 +1037,12 @@ absl::Status RegisterTimestampFunctions(CelFunctionRegistry* registry, }))); CEL_RETURN_IF_ERROR(registry->Register( - BinaryFunctionAdapter< - Handle, absl::Time, - const Handle&>::CreateDescriptor(builtin::kMilliseconds, - true), - BinaryFunctionAdapter, absl::Time, - const Handle&>:: + BinaryFunctionAdapter, absl::Time, const StringValue&>:: + CreateDescriptor(builtin::kMilliseconds, true), + BinaryFunctionAdapter, absl::Time, const StringValue&>:: WrapFunction([](ValueFactory& value_factory, absl::Time ts, - const Handle& tz) -> Handle { - return GetMilliseconds(value_factory, ts, tz->ToString()); + const StringValue& tz) -> Handle { + return GetMilliseconds(value_factory, ts, tz.ToString()); }))); return registry->Register( @@ -1072,23 +1058,23 @@ absl::Status RegisterBytesConversionFunctions(CelFunctionRegistry* registry, const InterpreterOptions&) { // bytes -> bytes CEL_RETURN_IF_ERROR(registry->Register( - UnaryFunctionAdapter, const Handle&>:: + UnaryFunctionAdapter, Handle>:: CreateDescriptor(builtin::kBytes, false), - UnaryFunctionAdapter, const Handle&>:: - WrapFunction([](ValueFactory&, const Handle& value) + UnaryFunctionAdapter, Handle>:: + WrapFunction([](ValueFactory&, Handle value) -> Handle { return value; }))); // string -> bytes return registry->Register( UnaryFunctionAdapter< absl::StatusOr>, - Handle>::CreateDescriptor(builtin::kBytes, false), - UnaryFunctionAdapter>, - Handle>:: - WrapFunction([](ValueFactory& value_factory, - const Handle& value) { - return value_factory.CreateBytesValue(value->ToString()); - })); + const StringValue&>::CreateDescriptor(builtin::kBytes, false), + UnaryFunctionAdapter< + absl::StatusOr>, + const StringValue&>::WrapFunction([](ValueFactory& value_factory, + const StringValue& value) { + return value_factory.CreateBytesValue(value.ToString()); + })); } absl::Status RegisterDoubleConversionFunctions(CelFunctionRegistry* registry, @@ -1109,13 +1095,13 @@ absl::Status RegisterDoubleConversionFunctions(CelFunctionRegistry* registry, // string -> double CEL_RETURN_IF_ERROR(registry->Register( - UnaryFunctionAdapter, const Handle&>:: - CreateDescriptor(builtin::kDouble, false), - UnaryFunctionAdapter, const Handle&>:: - WrapFunction([](ValueFactory& value_factory, - const Handle& s) -> Handle { + UnaryFunctionAdapter, const StringValue&>::CreateDescriptor( + builtin::kDouble, false), + UnaryFunctionAdapter, const StringValue&>::WrapFunction( + [](ValueFactory& value_factory, + const StringValue& s) -> Handle { double result; - if (absl::SimpleAtod(s->ToString(), &result)) { + if (absl::SimpleAtod(s.ToString(), &result)) { return value_factory.CreateDoubleValue(result); } else { return value_factory.CreateErrorValue(absl::InvalidArgumentError( @@ -1162,13 +1148,13 @@ absl::Status RegisterIntConversionFunctions(CelFunctionRegistry* registry, // string -> int CEL_RETURN_IF_ERROR(registry->Register( - UnaryFunctionAdapter, const Handle&>:: - CreateDescriptor(builtin::kInt, false), - UnaryFunctionAdapter, const Handle&>:: - WrapFunction([](ValueFactory& value_factory, - const Handle& s) -> Handle { + UnaryFunctionAdapter, const StringValue&>::CreateDescriptor( + builtin::kInt, false), + UnaryFunctionAdapter, const StringValue&>::WrapFunction( + [](ValueFactory& value_factory, + const StringValue& s) -> Handle { int64_t result; - if (!absl::SimpleAtoi(s->ToString(), &result)) { + if (!absl::SimpleAtoi(s.ToString(), &result)) { return value_factory.CreateErrorValue( absl::InvalidArgumentError("cannot convert string to int")); } @@ -1204,12 +1190,12 @@ absl::Status RegisterStringConversionFunctions( } CEL_RETURN_IF_ERROR(registry->Register( - UnaryFunctionAdapter, const Handle&>:: - CreateDescriptor(builtin::kString, false), - UnaryFunctionAdapter, const Handle&>:: - WrapFunction([](ValueFactory& value_factory, - const Handle& value) -> Handle { - auto handle_or = value_factory.CreateStringValue(value->ToString()); + UnaryFunctionAdapter, const BytesValue&>::CreateDescriptor( + builtin::kString, false), + UnaryFunctionAdapter, const BytesValue&>::WrapFunction( + [](ValueFactory& value_factory, + const BytesValue& value) -> Handle { + auto handle_or = value_factory.CreateStringValue(value.ToString()); if (!handle_or.ok()) { return value_factory.CreateErrorValue(handle_or.status()); } @@ -1239,10 +1225,10 @@ absl::Status RegisterStringConversionFunctions( // string -> string CEL_RETURN_IF_ERROR(registry->Register( - UnaryFunctionAdapter, const Handle&>:: + UnaryFunctionAdapter, Handle>:: CreateDescriptor(builtin::kString, false), - UnaryFunctionAdapter, const Handle&>:: - WrapFunction([](ValueFactory&, const Handle& value) + UnaryFunctionAdapter, Handle>:: + WrapFunction([](ValueFactory&, Handle value) -> Handle { return value; }))); // uint -> string @@ -1314,13 +1300,13 @@ absl::Status RegisterUintConversionFunctions(CelFunctionRegistry* registry, // string -> uint CEL_RETURN_IF_ERROR(registry->Register( - UnaryFunctionAdapter, const Handle&>:: - CreateDescriptor(builtin::kUint, false), - UnaryFunctionAdapter, const Handle&>:: - WrapFunction([](ValueFactory& value_factory, - const Handle& s) -> Handle { + UnaryFunctionAdapter, const StringValue&>::CreateDescriptor( + builtin::kUint, false), + UnaryFunctionAdapter, const StringValue&>::WrapFunction( + [](ValueFactory& value_factory, + const StringValue& s) -> Handle { uint64_t result; - if (!absl::SimpleAtoi(s->ToString(), &result)) { + if (!absl::SimpleAtoi(s.ToString(), &result)) { return value_factory.CreateErrorValue( absl::InvalidArgumentError("doesn't convert to a string")); } @@ -1343,10 +1329,10 @@ absl::Status RegisterConversionFunctions(CelFunctionRegistry* registry, // duration() conversion from string. CEL_RETURN_IF_ERROR(registry->Register( - UnaryFunctionAdapter, const Handle&>:: - CreateDescriptor(builtin::kDuration, false), - UnaryFunctionAdapter, const Handle&>:: - WrapFunction(CreateDurationFromString))); + UnaryFunctionAdapter, const StringValue&>::CreateDescriptor( + builtin::kDuration, false), + UnaryFunctionAdapter, const StringValue&>::WrapFunction( + CreateDurationFromString))); // dyn() identity function. // TODO(issues/102): strip dyn() function references at type-check time. @@ -1377,27 +1363,25 @@ absl::Status RegisterConversionFunctions(CelFunctionRegistry* registry, bool enable_timestamp_duration_overflow_errors = options.enable_timestamp_duration_overflow_errors; CEL_RETURN_IF_ERROR(registry->Register( - UnaryFunctionAdapter, const Handle&>:: - CreateDescriptor(builtin::kTimestamp, false), - UnaryFunctionAdapter, const Handle&>:: - WrapFunction( - [=](ValueFactory& value_factory, - const Handle& time_str) -> Handle { - absl::Time ts; - if (!absl::ParseTime(absl::RFC3339_full, time_str->ToString(), - &ts, nullptr)) { - return value_factory.CreateErrorValue( - absl::InvalidArgumentError( - "String to Timestamp conversion failed")); - } - if (enable_timestamp_duration_overflow_errors) { - if (ts < absl::UniversalEpoch() || ts > kMaxTime) { - return value_factory.CreateErrorValue( - absl::OutOfRangeError("timestamp overflow")); - } - } - return cel::interop_internal::CreateTimestampValue(ts); - }))); + UnaryFunctionAdapter, const StringValue&>::CreateDescriptor( + builtin::kTimestamp, false), + UnaryFunctionAdapter, const StringValue&>::WrapFunction( + [=](ValueFactory& value_factory, + const StringValue& time_str) -> Handle { + absl::Time ts; + if (!absl::ParseTime(absl::RFC3339_full, time_str.ToString(), &ts, + nullptr)) { + return value_factory.CreateErrorValue(absl::InvalidArgumentError( + "String to Timestamp conversion failed")); + } + if (enable_timestamp_duration_overflow_errors) { + if (ts < absl::UniversalEpoch() || ts > kMaxTime) { + return value_factory.CreateErrorValue( + absl::OutOfRangeError("timestamp overflow")); + } + } + return cel::interop_internal::CreateTimestampValue(ts); + }))); return RegisterUintConversionFunctions(registry, options); } @@ -1611,6 +1595,47 @@ absl::Status RegisterTimeFunctions(CelFunctionRegistry* registry, return absl::OkStatus(); } +int64_t MapSizeImpl(ValueFactory&, const MapValue& value) { + return value.size(); +} + +int64_t ListSizeImpl(ValueFactory&, const ListValue& value) { + return value.size(); +} + +absl::Status RegisterContainerFunctions(CelFunctionRegistry* registry, + const InterpreterOptions& options) { + // receiver style = true/false + // Support both the global and receiver style size() for lists and maps. + for (bool receiver_style : {true, false}) { + CEL_RETURN_IF_ERROR(registry->Register( + UnaryFunctionAdapter::CreateDescriptor( + builtin::kSize, receiver_style), + UnaryFunctionAdapter::WrapFunction( + ListSizeImpl))); + + CEL_RETURN_IF_ERROR(registry->Register( + UnaryFunctionAdapter::CreateDescriptor( + builtin::kSize, receiver_style), + UnaryFunctionAdapter::WrapFunction( + MapSizeImpl))); + } + + if (options.enable_list_concat) { + CEL_RETURN_IF_ERROR(registry->Register( + BinaryFunctionAdapter>, const ListValue&, + const ListValue&>::CreateDescriptor(builtin::kAdd, + false), + BinaryFunctionAdapter>, const ListValue&, + const ListValue&>::WrapFunction(ConcatList))); + } + + return registry->Register(PortableBinaryFunctionAdapter< + const CelList*, const CelList*, + const CelList*>::Create(builtin::kRuntimeListAppend, + false, AppendList)); +} + } // namespace absl::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry, @@ -1624,63 +1649,30 @@ absl::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry, &RegisterConversionFunctions, &RegisterTimeFunctions, &RegisterStringFunctions, + &RegisterRegexFunctions, + &RegisterSetMembershipFunctions, + &RegisterContainerFunctions, }, options)); - // List size - auto list_size_func = [](Arena*, const CelList* cel_list) -> int64_t { - return (*cel_list).size(); - }; - // receiver style = true/false - // Support both the global and receiver style size() for lists. - auto status = registry->Register( - PortableUnaryFunctionAdapter::Create( - builtin::kSize, true, list_size_func)); - if (!status.ok()) return status; - status = registry->Register( - PortableUnaryFunctionAdapter::Create( - builtin::kSize, false, list_size_func)); - if (!status.ok()) return status; - - // Map size - auto map_size_func = [](Arena*, const CelMap* cel_map) -> int64_t { - return (*cel_map).size(); - }; - // receiver style = true/false - status = registry->Register( - PortableUnaryFunctionAdapter::Create( - builtin::kSize, true, map_size_func)); - if (!status.ok()) return status; - status = registry->Register( - PortableUnaryFunctionAdapter::Create( - builtin::kSize, false, map_size_func)); - if (!status.ok()) return status; - - // Register set membership tests with the 'in' operator and its variants. - status = RegisterSetMembershipFunctions(registry, options); - if (!status.ok()) return status; - - if (options.enable_list_concat) { - status = registry->Register( - PortableBinaryFunctionAdapter::Create(builtin::kAdd, - false, - ConcatList)); - if (!status.ok()) return status; - } - - status = - registry->Register(PortableBinaryFunctionAdapter< - const CelList*, const CelList*, - const CelList*>::Create(builtin::kRuntimeListAppend, - false, AppendList)); - if (!status.ok()) return status; - return registry->Register( - PortableUnaryFunctionAdapter::Create( - builtin::kType, false, - [](Arena*, CelValue value) -> CelValue::CelTypeHolder { - return value.ObtainCelType().CelTypeOrDie(); + UnaryFunctionAdapter< + Handle, const Handle&>::CreateDescriptor(builtin::kType, + false), + UnaryFunctionAdapter, const Handle&>::WrapFunction( + [](ValueFactory& factory, const Handle& value) { + // TODO(issues/5): legacy types don't interop with type values + // from factory. This should simply be: + // + // return factory.CreateTypeValue(value->type()); + Arena* arena = + cel::extensions::ProtoMemoryManager::CastToProtoArena( + factory.memory_manager()); + CelValue legacy_value = + cel::interop_internal::ModernValueToLegacyValueOrDie(arena, + value); + return cel::interop_internal::CreateTypeValueFromView( + legacy_value.ObtainCelType().CelTypeOrDie().value()); })); } From 477e7e88ba3a75fe3bd8920a45d50ba05fcfe621 Mon Sep 17 00:00:00 2001 From: jdtatum Date: Mon, 24 Apr 2023 19:09:42 +0000 Subject: [PATCH 223/303] Simplify interface for Qualified reference resolver. Follow-up will adapt it to implement a generic interface. No functional changes. PiperOrigin-RevId: 526717877 --- base/internal/ast_impl.h | 5 + eval/compiler/BUILD | 7 +- eval/compiler/flat_expr_builder.cc | 5 +- eval/compiler/qualified_reference_resolver.cc | 28 +- eval/compiler/qualified_reference_resolver.h | 11 +- .../qualified_reference_resolver_test.cc | 252 +++++++++--------- 6 files changed, 149 insertions(+), 159 deletions(-) diff --git a/base/internal/ast_impl.h b/base/internal/ast_impl.h index bf1e3a1cc..04a48447f 100644 --- a/base/internal/ast_impl.h +++ b/base/internal/ast_impl.h @@ -84,6 +84,11 @@ class AstImpl : public Ast { const absl::flat_hash_map& reference_map() const { return reference_map_; } + + absl::flat_hash_map& reference_map() { + return reference_map_; + } + const absl::flat_hash_map& type_map() const { return type_map_; } diff --git a/eval/compiler/BUILD b/eval/compiler/BUILD index 329e285b3..901a19e32 100644 --- a/eval/compiler/BUILD +++ b/eval/compiler/BUILD @@ -192,7 +192,9 @@ cc_library( ], deps = [ ":resolver", + "//base:ast", "//base:ast_internal", + "//base/internal:ast_impl", "//eval/eval:const_value_step", "//eval/eval:expression_build_warning", "//eval/public:ast_rewrite_native", @@ -204,7 +206,6 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", - "@com_google_protobuf//:protobuf", ], ) @@ -231,16 +232,20 @@ cc_test( ], deps = [ ":qualified_reference_resolver", + "//base:ast", + "//base/internal:ast_impl", "//eval/public:builtin_func_registrar", "//eval/public:cel_builtins", "//eval/public:cel_function", "//eval/public:cel_function_registry", "//eval/public:cel_type_registry", "//extensions/protobuf:ast_converters", + "//internal:casts", "//internal:status_macros", "//internal:testing", "//testutil:util", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/types:optional", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index 3c88a6fdb..47ecfa4ce 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -1266,9 +1266,8 @@ FlatExprBuilder::CreateExpressionImpl( // available, we can skip the reference resolve step here if it's already // done. if (rewrites_enabled) { - absl::StatusOr rewritten = ResolveReferences( - &ast_impl.reference_map(), resolver, &ast_impl.source_info(), - warnings_builder, &ast_impl.root_expr()); + absl::StatusOr rewritten = + ResolveReferences(resolver, warnings_builder, ast_impl); if (!rewritten.ok()) { return rewritten.status(); } diff --git a/eval/compiler/qualified_reference_resolver.cc b/eval/compiler/qualified_reference_resolver.cc index 9b5fb93f5..b11e70c22 100644 --- a/eval/compiler/qualified_reference_resolver.cc +++ b/eval/compiler/qualified_reference_resolver.cc @@ -11,6 +11,8 @@ #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "base/ast.h" +#include "base/internal/ast_impl.h" #include "eval/eval/const_value_step.h" #include "eval/eval/expression_build_warning.h" #include "eval/public/ast_rewrite_native.h" @@ -76,7 +78,7 @@ absl::optional BestOverloadMatch(const Resolver& resolver, class ReferenceResolver : public cel::ast::internal::AstRewriterBase { public: ReferenceResolver( - const absl::flat_hash_map* reference_map, + const absl::flat_hash_map& reference_map, const Resolver& resolver, BuilderWarnings& warnings) : reference_map_(reference_map), resolver_(resolver), @@ -247,12 +249,8 @@ class ReferenceResolver : public cel::ast::internal::AstRewriterBase { // // Returns nullptr if no reference is available. const Reference* GetReferenceForId(int64_t expr_id) { - if (reference_map_ == nullptr) { - return nullptr; - } - - auto iter = reference_map_->find(expr_id); - if (iter == reference_map_->end()) { + auto iter = reference_map_.find(expr_id); + if (iter == reference_map_.end()) { return nullptr; } if (expr_id == 0) { @@ -265,7 +263,7 @@ class ReferenceResolver : public cel::ast::internal::AstRewriterBase { return &iter->second; } - const absl::flat_hash_map* reference_map_; + const absl::flat_hash_map& reference_map_; const Resolver& resolver_; BuilderWarnings& warnings_; absl::flat_hash_set rewritten_reference_; @@ -273,17 +271,15 @@ class ReferenceResolver : public cel::ast::internal::AstRewriterBase { } // namespace -absl::StatusOr ResolveReferences( - const absl::flat_hash_map* - reference_map, - const Resolver& resolver, const cel::ast::internal::SourceInfo* source_info, - BuilderWarnings& warnings, cel::ast::internal::Expr* expr) { - ReferenceResolver ref_resolver(reference_map, resolver, warnings); +absl::StatusOr ResolveReferences(const Resolver& resolver, + BuilderWarnings& warnings, + cel::ast::internal::AstImpl& ast) { + ReferenceResolver ref_resolver(ast.reference_map(), resolver, warnings); // Rewriting interface doesn't support failing mid traverse propagate first // error encountered if fail fast enabled. - bool was_rewritten = - cel::ast::internal::AstRewrite(expr, source_info, &ref_resolver); + bool was_rewritten = cel::ast::internal::AstRewrite( + &ast.root_expr(), &ast.source_info(), &ref_resolver); if (warnings.fail_immediately() && !warnings.warnings().empty()) { return warnings.warnings().front(); } diff --git a/eval/compiler/qualified_reference_resolver.h b/eval/compiler/qualified_reference_resolver.h index b105a78eb..8851c2e97 100644 --- a/eval/compiler/qualified_reference_resolver.h +++ b/eval/compiler/qualified_reference_resolver.h @@ -3,9 +3,8 @@ #include -#include "google/protobuf/map.h" #include "absl/status/statusor.h" -#include "absl/strings/string_view.h" +#include "base/ast.h" #include "base/ast_internal.h" #include "eval/compiler/resolver.h" #include "eval/eval/expression_build_warning.h" @@ -21,11 +20,9 @@ namespace google::api::expr::runtime { // Will warn or return a non-ok status if references can't be resolved (no // function overload could match a call) or are inconsistnet (reference map // points to an expr node that isn't a reference). -absl::StatusOr ResolveReferences( - const absl::flat_hash_map* - reference_map, - const Resolver& resolver, const cel::ast::internal::SourceInfo* source_info, - BuilderWarnings& warnings, cel::ast::internal::Expr* expr); +absl::StatusOr ResolveReferences(const Resolver& resolver, + BuilderWarnings& warnings, + cel::ast::internal::AstImpl& ast); } // namespace google::api::expr::runtime diff --git a/eval/compiler/qualified_reference_resolver_test.cc b/eval/compiler/qualified_reference_resolver_test.cc index c42521760..fe7100673 100644 --- a/eval/compiler/qualified_reference_resolver_test.cc +++ b/eval/compiler/qualified_reference_resolver_test.cc @@ -1,18 +1,23 @@ #include "eval/compiler/qualified_reference_resolver.h" #include +#include #include #include "google/api/expr/v1alpha1/syntax.pb.h" #include "absl/container/flat_hash_map.h" +#include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/types/optional.h" +#include "base/ast.h" +#include "base/internal/ast_impl.h" #include "eval/public/builtin_func_registrar.h" #include "eval/public/cel_builtins.h" #include "eval/public/cel_function.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_type_registry.h" #include "extensions/protobuf/ast_converters.h" +#include "internal/casts.h" #include "internal/status_macros.h" #include "internal/testing.h" #include "testutil/util.h" @@ -21,6 +26,8 @@ namespace google::api::expr::runtime { namespace { +using ::cel::ast::Ast; +using ::cel::ast::internal::AstImpl; using ::cel::ast::internal::Expr; using ::cel::ast::internal::Reference; using ::cel::ast::internal::SourceInfo; @@ -78,25 +85,23 @@ MATCHER_P(StatusCodeIs, x, "") { return status.code() == x; } -Expr ParseTestProto(const std::string& pb) { +std::unique_ptr ParseTestProto(const std::string& pb) { google::api::expr::v1alpha1::Expr expr; EXPECT_TRUE(google::protobuf::TextFormat::ParseFromString(pb, &expr)); - return ConvertProtoExprToNative(expr).value(); + return absl::WrapUnique(cel::internal::down_cast( + cel::extensions::CreateAstFromParsedExpr(expr).value().release())); } TEST(ResolveReferences, Basic) { - Expr expr = ParseTestProto(kExpr); - SourceInfo source_info; - absl::flat_hash_map reference_map; - reference_map[2].set_name("foo.bar.var1"); - reference_map[5].set_name("bar.foo.var2"); + std::unique_ptr expr_ast = ParseTestProto(kExpr); + expr_ast->reference_map()[2].set_name("foo.bar.var1"); + expr_ast->reference_map()[5].set_name("bar.foo.var2"); BuilderWarnings warnings; CelFunctionRegistry func_registry; CelTypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), &type_registry); - auto result = ResolveReferences(&reference_map, registry, &source_info, - warnings, &expr); + auto result = ResolveReferences(registry, warnings, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(true)); google::api::expr::v1alpha1::Expr expected_expr; google::protobuf::TextFormat::ParseFromString(R"pb( @@ -113,44 +118,39 @@ TEST(ResolveReferences, Basic) { } })pb", &expected_expr); - EXPECT_EQ(expr, ConvertProtoExprToNative(expected_expr).value()); + EXPECT_EQ(expr_ast->root_expr(), + ConvertProtoExprToNative(expected_expr).value()); } TEST(ResolveReferences, ReturnsFalseIfNoChanges) { - Expr expr = ParseTestProto(kExpr); - SourceInfo source_info; - absl::flat_hash_map reference_map; + std::unique_ptr expr_ast = ParseTestProto(kExpr); BuilderWarnings warnings; CelFunctionRegistry func_registry; CelTypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), &type_registry); - auto result = ResolveReferences(&reference_map, registry, &source_info, - warnings, &expr); + auto result = ResolveReferences(registry, warnings, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(false)); // reference to the same name also doesn't count as a rewrite. - reference_map[4].set_name("foo"); - reference_map[7].set_name("bar"); + expr_ast->reference_map()[4].set_name("foo"); + expr_ast->reference_map()[7].set_name("bar"); - result = ResolveReferences(&reference_map, registry, &source_info, warnings, - &expr); + result = ResolveReferences(registry, warnings, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(false)); } TEST(ResolveReferences, NamespacedIdent) { - Expr expr = ParseTestProto(kExpr); + std::unique_ptr expr_ast = ParseTestProto(kExpr); SourceInfo source_info; - absl::flat_hash_map reference_map; BuilderWarnings warnings; CelFunctionRegistry func_registry; CelTypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), &type_registry); - reference_map[2].set_name("foo.bar.var1"); - reference_map[7].set_name("namespace_x.bar"); + expr_ast->reference_map()[2].set_name("foo.bar.var1"); + expr_ast->reference_map()[7].set_name("namespace_x.bar"); - auto result = ResolveReferences(&reference_map, registry, &source_info, - warnings, &expr); + auto result = ResolveReferences(registry, warnings, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(true)); google::api::expr::v1alpha1::Expr expected_expr; google::protobuf::TextFormat::ParseFromString( @@ -180,11 +180,12 @@ TEST(ResolveReferences, NamespacedIdent) { } })pb", &expected_expr); - EXPECT_EQ(expr, ConvertProtoExprToNative(expected_expr).value()); + EXPECT_EQ(expr_ast->root_expr(), + ConvertProtoExprToNative(expected_expr).value()); } TEST(ResolveReferences, WarningOnPresenceTest) { - Expr expr = ParseTestProto(R"( + std::unique_ptr expr_ast = ParseTestProto(R"pb( id: 1 select_expr { field: "var1" @@ -199,18 +200,16 @@ TEST(ResolveReferences, WarningOnPresenceTest) { } } } - })"); + })pb"); SourceInfo source_info; - absl::flat_hash_map reference_map; BuilderWarnings warnings; CelFunctionRegistry func_registry; CelTypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), &type_registry); - reference_map[1].set_name("foo.bar.var1"); + expr_ast->reference_map()[1].set_name("foo.bar.var1"); - auto result = ResolveReferences(&reference_map, registry, &source_info, - warnings, &expr); + auto result = ResolveReferences(registry, warnings, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(false)); EXPECT_THAT( @@ -249,21 +248,19 @@ constexpr char kEnumExpr[] = R"( )"; TEST(ResolveReferences, EnumConstReferenceUsed) { - Expr expr = ParseTestProto(kEnumExpr); + std::unique_ptr expr_ast = ParseTestProto(kEnumExpr); SourceInfo source_info; - absl::flat_hash_map reference_map; CelFunctionRegistry func_registry; ASSERT_OK(RegisterBuiltinFunctions(&func_registry)); CelTypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), &type_registry); - reference_map[2].set_name("foo.bar.var1"); - reference_map[5].set_name("bar.foo.Enum.ENUM_VAL1"); - reference_map[5].mutable_value().set_int64_value(9); + expr_ast->reference_map()[2].set_name("foo.bar.var1"); + expr_ast->reference_map()[5].set_name("bar.foo.Enum.ENUM_VAL1"); + expr_ast->reference_map()[5].mutable_value().set_int64_value(9); BuilderWarnings warnings; - auto result = ResolveReferences(&reference_map, registry, &source_info, - warnings, &expr); + auto result = ResolveReferences(registry, warnings, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(true)); google::api::expr::v1alpha1::Expr expected_expr; @@ -281,26 +278,25 @@ TEST(ResolveReferences, EnumConstReferenceUsed) { } })pb", &expected_expr); - EXPECT_EQ(expr, ConvertProtoExprToNative(expected_expr).value()); + EXPECT_EQ(expr_ast->root_expr(), + ConvertProtoExprToNative(expected_expr).value()); } TEST(ResolveReferences, EnumConstReferenceUsedSelect) { - Expr expr = ParseTestProto(kEnumExpr); + std::unique_ptr expr_ast = ParseTestProto(kEnumExpr); SourceInfo source_info; - absl::flat_hash_map reference_map; CelFunctionRegistry func_registry; ASSERT_OK(RegisterBuiltinFunctions(&func_registry)); CelTypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), &type_registry); - reference_map[2].set_name("foo.bar.var1"); - reference_map[2].mutable_value().set_int64_value(2); - reference_map[5].set_name("bar.foo.Enum.ENUM_VAL1"); - reference_map[5].mutable_value().set_int64_value(9); + expr_ast->reference_map()[2].set_name("foo.bar.var1"); + expr_ast->reference_map()[2].mutable_value().set_int64_value(2); + expr_ast->reference_map()[5].set_name("bar.foo.Enum.ENUM_VAL1"); + expr_ast->reference_map()[5].mutable_value().set_int64_value(9); BuilderWarnings warnings; - auto result = ResolveReferences(&reference_map, registry, &source_info, - warnings, &expr); + auto result = ResolveReferences(registry, warnings, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(true)); google::api::expr::v1alpha1::Expr expected_expr; @@ -318,25 +314,24 @@ TEST(ResolveReferences, EnumConstReferenceUsedSelect) { } })pb", &expected_expr); - EXPECT_EQ(expr, ConvertProtoExprToNative(expected_expr).value()); + EXPECT_EQ(expr_ast->root_expr(), + ConvertProtoExprToNative(expected_expr).value()); } TEST(ResolveReferences, ConstReferenceSkipped) { - Expr expr = ParseTestProto(kExpr); + std::unique_ptr expr_ast = ParseTestProto(kExpr); SourceInfo source_info; - absl::flat_hash_map reference_map; CelFunctionRegistry func_registry; ASSERT_OK(RegisterBuiltinFunctions(&func_registry)); CelTypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), &type_registry); - reference_map[2].set_name("foo.bar.var1"); - reference_map[2].mutable_value().set_bool_value(true); - reference_map[5].set_name("bar.foo.var2"); + expr_ast->reference_map()[2].set_name("foo.bar.var1"); + expr_ast->reference_map()[2].mutable_value().set_bool_value(true); + expr_ast->reference_map()[5].set_name("bar.foo.var2"); BuilderWarnings warnings; - auto result = ResolveReferences(&reference_map, registry, &source_info, - warnings, &expr); + auto result = ResolveReferences(registry, warnings, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(true)); google::api::expr::v1alpha1::Expr expected_expr; @@ -366,7 +361,8 @@ TEST(ResolveReferences, ConstReferenceSkipped) { } })pb", &expected_expr); - EXPECT_EQ(expr, ConvertProtoExprToNative(expected_expr).value()); + EXPECT_EQ(expr_ast->root_expr(), + ConvertProtoExprToNative(expected_expr).value()); } constexpr char kExtensionAndExpr[] = R"( @@ -388,10 +384,9 @@ call_expr { })"; TEST(ResolveReferences, FunctionReferenceBasic) { - Expr expr = ParseTestProto(kExtensionAndExpr); + std::unique_ptr expr_ast = ParseTestProto(kExtensionAndExpr); SourceInfo source_info; - absl::flat_hash_map reference_map; CelFunctionRegistry func_registry; ASSERT_OK(func_registry.RegisterLazyFunction( CelFunctionDescriptor("boolean_and", false, @@ -402,27 +397,26 @@ TEST(ResolveReferences, FunctionReferenceBasic) { CelTypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), &type_registry); BuilderWarnings warnings; - reference_map[1].mutable_overload_id().push_back("udf_boolean_and"); + expr_ast->reference_map()[1].mutable_overload_id().push_back( + "udf_boolean_and"); - auto result = ResolveReferences(&reference_map, registry, &source_info, - warnings, &expr); + auto result = ResolveReferences(registry, warnings, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(false)); } TEST(ResolveReferences, FunctionReferenceMissingOverloadDetected) { - Expr expr = ParseTestProto(kExtensionAndExpr); + std::unique_ptr expr_ast = ParseTestProto(kExtensionAndExpr); SourceInfo source_info; - absl::flat_hash_map reference_map; CelFunctionRegistry func_registry; CelTypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), &type_registry); BuilderWarnings warnings; - reference_map[1].mutable_overload_id().push_back("udf_boolean_and"); + expr_ast->reference_map()[1].mutable_overload_id().push_back( + "udf_boolean_and"); - auto result = ResolveReferences(&reference_map, registry, &source_info, - warnings, &expr); + auto result = ResolveReferences(registry, warnings, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(false)); EXPECT_THAT(warnings.warnings(), @@ -430,7 +424,7 @@ TEST(ResolveReferences, FunctionReferenceMissingOverloadDetected) { } TEST(ResolveReferences, SpecialBuiltinsNotWarned) { - Expr expr = ParseTestProto(R"( + std::unique_ptr expr_ast = ParseTestProto(R"pb( id: 1 call_expr { function: "*" @@ -442,24 +436,22 @@ TEST(ResolveReferences, SpecialBuiltinsNotWarned) { id: 3 const_expr { bool_value: false } } - })"); + })pb"); SourceInfo source_info; std::vector special_builtins{builtin::kAnd, builtin::kOr, builtin::kTernary, builtin::kIndex}; for (const char* builtin_fn : special_builtins) { - absl::flat_hash_map reference_map; // Builtins aren't in the function registry. CelFunctionRegistry func_registry; CelTypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), &type_registry); BuilderWarnings warnings; - reference_map[1].mutable_overload_id().push_back( + expr_ast->reference_map()[1].mutable_overload_id().push_back( absl::StrCat("builtin.", builtin_fn)); - expr.mutable_call_expr().set_function(builtin_fn); + expr_ast->root_expr().mutable_call_expr().set_function(builtin_fn); - auto result = ResolveReferences(&reference_map, registry, &source_info, - warnings, &expr); + auto result = ResolveReferences(registry, warnings, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(false)); EXPECT_THAT(warnings.warnings(), IsEmpty()); @@ -468,18 +460,16 @@ TEST(ResolveReferences, SpecialBuiltinsNotWarned) { TEST(ResolveReferences, FunctionReferenceMissingOverloadDetectedAndMissingReference) { - Expr expr = ParseTestProto(kExtensionAndExpr); + std::unique_ptr expr_ast = ParseTestProto(kExtensionAndExpr); SourceInfo source_info; - absl::flat_hash_map reference_map; CelFunctionRegistry func_registry; CelTypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), &type_registry); BuilderWarnings warnings; - reference_map[1].set_name("udf_boolean_and"); + expr_ast->reference_map()[1].set_name("udf_boolean_and"); - auto result = ResolveReferences(&reference_map, registry, &source_info, - warnings, &expr); + auto result = ResolveReferences(registry, warnings, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(false)); EXPECT_THAT( @@ -492,36 +482,33 @@ TEST(ResolveReferences, } TEST(ResolveReferences, EmulatesEagerFailing) { - Expr expr = ParseTestProto(kExtensionAndExpr); + std::unique_ptr expr_ast = ParseTestProto(kExtensionAndExpr); SourceInfo source_info; - absl::flat_hash_map reference_map; CelFunctionRegistry func_registry; CelTypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), &type_registry); BuilderWarnings warnings(/*fail_eagerly=*/true); - reference_map[1].set_name("udf_boolean_and"); + expr_ast->reference_map()[1].set_name("udf_boolean_and"); EXPECT_THAT( - ResolveReferences(&reference_map, registry, &source_info, warnings, - &expr), + ResolveReferences(registry, warnings, *expr_ast), StatusIs(absl::StatusCode::kInvalidArgument, "Reference map doesn't provide overloads for boolean_and")); } TEST(ResolveReferences, FunctionReferenceToWrongExprKind) { - Expr expr = ParseTestProto(kExtensionAndExpr); + std::unique_ptr expr_ast = ParseTestProto(kExtensionAndExpr); SourceInfo source_info; - absl::flat_hash_map reference_map; BuilderWarnings warnings; CelFunctionRegistry func_registry; CelTypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), &type_registry); - reference_map[2].mutable_overload_id().push_back("udf_boolean_and"); + expr_ast->reference_map()[2].mutable_overload_id().push_back( + "udf_boolean_and"); - auto result = ResolveReferences(&reference_map, registry, &source_info, - warnings, &expr); + auto result = ResolveReferences(registry, warnings, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(false)); EXPECT_THAT(warnings.warnings(), @@ -547,20 +534,20 @@ call_expr { })"; TEST(ResolveReferences, FunctionReferenceWithTargetNoChange) { - Expr expr = ParseTestProto(kReceiverCallExtensionAndExpr); + std::unique_ptr expr_ast = + ParseTestProto(kReceiverCallExtensionAndExpr); SourceInfo source_info; - absl::flat_hash_map reference_map; BuilderWarnings warnings; CelFunctionRegistry func_registry; ASSERT_OK(func_registry.RegisterLazyFunction(CelFunctionDescriptor( "boolean_and", true, {CelValue::Type::kBool, CelValue::Type::kBool}))); CelTypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), &type_registry); - reference_map[1].mutable_overload_id().push_back("udf_boolean_and"); + expr_ast->reference_map()[1].mutable_overload_id().push_back( + "udf_boolean_and"); - auto result = ResolveReferences(&reference_map, registry, &source_info, - warnings, &expr); + auto result = ResolveReferences(registry, warnings, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(false)); EXPECT_THAT(warnings.warnings(), IsEmpty()); @@ -568,18 +555,18 @@ TEST(ResolveReferences, FunctionReferenceWithTargetNoChange) { TEST(ResolveReferences, FunctionReferenceWithTargetNoChangeMissingOverloadDetected) { - Expr expr = ParseTestProto(kReceiverCallExtensionAndExpr); + std::unique_ptr expr_ast = + ParseTestProto(kReceiverCallExtensionAndExpr); SourceInfo source_info; - absl::flat_hash_map reference_map; BuilderWarnings warnings; CelFunctionRegistry func_registry; CelTypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), &type_registry); - reference_map[1].mutable_overload_id().push_back("udf_boolean_and"); + expr_ast->reference_map()[1].mutable_overload_id().push_back( + "udf_boolean_and"); - auto result = ResolveReferences(&reference_map, registry, &source_info, - warnings, &expr); + auto result = ResolveReferences(registry, warnings, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(false)); EXPECT_THAT(warnings.warnings(), @@ -587,20 +574,20 @@ TEST(ResolveReferences, } TEST(ResolveReferences, FunctionReferenceWithTargetToNamespacedFunction) { - Expr expr = ParseTestProto(kReceiverCallExtensionAndExpr); + std::unique_ptr expr_ast = + ParseTestProto(kReceiverCallExtensionAndExpr); SourceInfo source_info; - absl::flat_hash_map reference_map; BuilderWarnings warnings; CelFunctionRegistry func_registry; ASSERT_OK(func_registry.RegisterLazyFunction(CelFunctionDescriptor( "ext.boolean_and", false, {CelValue::Type::kBool}))); CelTypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), &type_registry); - reference_map[1].mutable_overload_id().push_back("udf_boolean_and"); + expr_ast->reference_map()[1].mutable_overload_id().push_back( + "udf_boolean_and"); - auto result = ResolveReferences(&reference_map, registry, &source_info, - warnings, &expr); + auto result = ResolveReferences(registry, warnings, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(true)); google::api::expr::v1alpha1::Expr expected_expr; @@ -615,17 +602,19 @@ TEST(ResolveReferences, FunctionReferenceWithTargetToNamespacedFunction) { } )pb", &expected_expr); - EXPECT_EQ(expr, ConvertProtoExprToNative(expected_expr).value()); + EXPECT_EQ(expr_ast->root_expr(), + ConvertProtoExprToNative(expected_expr).value()); EXPECT_THAT(warnings.warnings(), IsEmpty()); } TEST(ResolveReferences, FunctionReferenceWithTargetToNamespacedFunctionInContainer) { - Expr expr = ParseTestProto(kReceiverCallExtensionAndExpr); + std::unique_ptr expr_ast = + ParseTestProto(kReceiverCallExtensionAndExpr); SourceInfo source_info; - absl::flat_hash_map reference_map; - reference_map[1].mutable_overload_id().push_back("udf_boolean_and"); + expr_ast->reference_map()[1].mutable_overload_id().push_back( + "udf_boolean_and"); BuilderWarnings warnings; CelFunctionRegistry func_registry; ASSERT_OK(func_registry.RegisterLazyFunction(CelFunctionDescriptor( @@ -633,8 +622,7 @@ TEST(ResolveReferences, CelTypeRegistry type_registry; Resolver registry("com.google", func_registry.InternalGetRegistry(), &type_registry); - auto result = ResolveReferences(&reference_map, registry, &source_info, - warnings, &expr); + auto result = ResolveReferences(registry, warnings, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(true)); google::api::expr::v1alpha1::Expr expected_expr; @@ -649,7 +637,8 @@ TEST(ResolveReferences, } )pb", &expected_expr); - EXPECT_EQ(expr, ConvertProtoExprToNative(expected_expr).value()); + EXPECT_EQ(expr_ast->root_expr(), + ConvertProtoExprToNative(expected_expr).value()); EXPECT_THAT(warnings.warnings(), IsEmpty()); } @@ -680,10 +669,10 @@ call_expr { })"; TEST(ResolveReferences, FunctionReferenceWithHasTargetNoChange) { - Expr expr = ParseTestProto(kReceiverCallHasExtensionAndExpr); + std::unique_ptr expr_ast = + ParseTestProto(kReceiverCallHasExtensionAndExpr); SourceInfo source_info; - absl::flat_hash_map reference_map; BuilderWarnings warnings; CelFunctionRegistry func_registry; ASSERT_OK(func_registry.RegisterLazyFunction(CelFunctionDescriptor( @@ -692,17 +681,18 @@ TEST(ResolveReferences, FunctionReferenceWithHasTargetNoChange) { "ext.option.boolean_and", true, {CelValue::Type::kBool}))); CelTypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), &type_registry); - reference_map[1].mutable_overload_id().push_back("udf_boolean_and"); + expr_ast->reference_map()[1].mutable_overload_id().push_back( + "udf_boolean_and"); - auto result = ResolveReferences(&reference_map, registry, &source_info, - warnings, &expr); + auto result = ResolveReferences(registry, warnings, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(false)); // The target is unchanged because it is a test_only select. google::api::expr::v1alpha1::Expr expected_expr; google::protobuf::TextFormat::ParseFromString(kReceiverCallHasExtensionAndExpr, &expected_expr); - EXPECT_EQ(expr, ConvertProtoExprToNative(expected_expr).value()); + EXPECT_EQ(expr_ast->root_expr(), + ConvertProtoExprToNative(expected_expr).value()); EXPECT_THAT(warnings.warnings(), IsEmpty()); } @@ -774,23 +764,21 @@ comprehension_expr: { } )"; TEST(ResolveReferences, EnumConstReferenceUsedInComprehension) { - Expr expr = ParseTestProto(kComprehensionExpr); + std::unique_ptr expr_ast = ParseTestProto(kComprehensionExpr); SourceInfo source_info; - absl::flat_hash_map reference_map; CelFunctionRegistry func_registry; ASSERT_OK(RegisterBuiltinFunctions(&func_registry)); CelTypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), &type_registry); - reference_map[3].set_name("ENUM"); - reference_map[3].mutable_value().set_int64_value(2); - reference_map[7].set_name("ENUM"); - reference_map[7].mutable_value().set_int64_value(2); + expr_ast->reference_map()[3].set_name("ENUM"); + expr_ast->reference_map()[3].mutable_value().set_int64_value(2); + expr_ast->reference_map()[7].set_name("ENUM"); + expr_ast->reference_map()[7].mutable_value().set_int64_value(2); BuilderWarnings warnings; - auto result = ResolveReferences(&reference_map, registry, &source_info, - warnings, &expr); + auto result = ResolveReferences(registry, warnings, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(true)); google::api::expr::v1alpha1::Expr expected_expr; @@ -867,13 +855,14 @@ TEST(ResolveReferences, EnumConstReferenceUsedInComprehension) { } })pb", &expected_expr); - EXPECT_EQ(expr, ConvertProtoExprToNative(expected_expr).value()); + EXPECT_EQ(expr_ast->root_expr(), + ConvertProtoExprToNative(expected_expr).value()); } TEST(ResolveReferences, ReferenceToId0Warns) { // ID 0 is unsupported since it is not normally used by parsers and is // ambiguous as an intentional ID or default for unset field. - Expr expr = ParseTestProto(R"pb( + std::unique_ptr expr_ast = ParseTestProto(R"pb( id: 0 select_expr { operand { @@ -885,16 +874,14 @@ TEST(ResolveReferences, ReferenceToId0Warns) { SourceInfo source_info; - absl::flat_hash_map reference_map; CelFunctionRegistry func_registry; ASSERT_OK(RegisterBuiltinFunctions(&func_registry)); CelTypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), &type_registry); - reference_map[0].set_name("pkg.var"); + expr_ast->reference_map()[0].set_name("pkg.var"); BuilderWarnings warnings; - auto result = ResolveReferences(&reference_map, registry, &source_info, - warnings, &expr); + auto result = ResolveReferences(registry, warnings, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(false)); google::api::expr::v1alpha1::Expr expected_expr; @@ -908,7 +895,8 @@ TEST(ResolveReferences, ReferenceToId0Warns) { field: "var" })pb", &expected_expr); - EXPECT_EQ(expr, ConvertProtoExprToNative(expected_expr).value()); + EXPECT_EQ(expr_ast->root_expr(), + ConvertProtoExprToNative(expected_expr).value()); EXPECT_THAT( warnings.warnings(), Contains(StatusIs( From 0f1f3b67fdc8131e4ccedeffaaaf7102c2d57313 Mon Sep 17 00:00:00 2001 From: jdtatum Date: Mon, 24 Apr 2023 19:28:21 +0000 Subject: [PATCH 224/303] Update const folding to not change accu_init for map/filter comprehensions. PiperOrigin-RevId: 526722831 --- eval/compiler/constant_folding.cc | 9 ++++++++ eval/tests/benchmark_test.cc | 35 +++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/eval/compiler/constant_folding.cc b/eval/compiler/constant_folding.cc index 5452aacc2..0a0499dd5 100644 --- a/eval/compiler/constant_folding.cc +++ b/eval/compiler/constant_folding.cc @@ -253,10 +253,19 @@ class ConstantFoldingTransform { transform_.Transform(expr_.list_expr().elements()[i], element) && all_constant; } + if (!all_constant) { return false; } + if (list_size == 0) { + // TODO(issues/5): need a more robust fix to support generic + // comprehensions, but this will allow comprehension list append + // optimization to work to prevent quadratic memory consumption for + // map/filter. + return false; + } + // create a constant list value std::vector> values(list_size); for (int i = 0; i < list_size; i++) { diff --git a/eval/tests/benchmark_test.cc b/eval/tests/benchmark_test.cc index b0cf09aad..647823618 100644 --- a/eval/tests/benchmark_test.cc +++ b/eval/tests/benchmark_test.cc @@ -922,6 +922,41 @@ void BM_ListComprehension_Trace(benchmark::State& state) { BENCHMARK(BM_ListComprehension_Trace)->Range(1, 1 << 16); +void BM_ListComprehension_Opt(benchmark::State& state) { + google::protobuf::Arena arena; + Activation activation; + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, + parser::Parse("list.map(x, x * 2)")); + + int len = state.range(0); + std::vector list; + list.reserve(len); + for (int i = 0; i < len; i++) { + list.push_back(CelValue::CreateInt64(1)); + } + + ContainerBackedListImpl cel_list(std::move(list)); + activation.InsertValue("list", CelValue::CreateList(&cel_list)); + InterpreterOptions options; + options.constant_arena = &arena; + options.constant_folding = true; + options.comprehension_max_iterations = 10000000; + options.enable_comprehension_list_append = true; + auto builder = CreateCelExpressionBuilder(options); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + ASSERT_OK_AND_ASSIGN( + auto cel_expr, builder->CreateExpression(&(parsed_expr.expr()), nullptr)); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsList()); + ASSERT_EQ(result.ListOrDie()->size(), len); + } +} + +BENCHMARK(BM_ListComprehension_Opt)->Range(1, 1 << 16); + void BM_ComprehensionCpp(benchmark::State& state) { google::protobuf::Arena arena; Activation activation; From 3a4d58b6cac36eb0a9246e8c7a5565144c68ad3d Mon Sep 17 00:00:00 2001 From: jcking Date: Tue, 25 Apr 2023 22:17:58 +0000 Subject: [PATCH 225/303] Add JSON-like wrappers to builtin type provider PiperOrigin-RevId: 527089171 --- base/type_provider.cc | 14 +++++- base/type_provider_test.cc | 94 +++++++++++++++++++++++++++++++++----- 2 files changed, 95 insertions(+), 13 deletions(-) diff --git a/base/type_provider.cc b/base/type_provider.cc index d7d33ee97..c60e73980 100644 --- a/base/type_provider.cc +++ b/base/type_provider.cc @@ -45,8 +45,11 @@ class BuiltinTypeProvider final : public TypeProvider { {"google.protobuf.Duration", GetDurationType}, {"google.protobuf.Timestamp", GetTimestampType}, {"list", GetListType}, + {"google.protobuf.ListValue", GetListType}, {"map", GetMapType}, + {"google.protobuf.Struct", GetStructType}, {"type", GetTypeType}, + {"google.protobuf.Value", GetValueType}, {"google.protobuf.BoolValue", GetBoolWrapperType}, {"google.protobuf.BytesValue", GetBytesWrapperType}, {"google.protobuf.DoubleValue", GetDoubleWrapperType}, @@ -156,11 +159,20 @@ class BuiltinTypeProvider final : public TypeProvider { return HandleFactory::Make(); } + static absl::StatusOr> GetStructType(TypeFactory& type_factory) { + return type_factory.CreateMapType(type_factory.GetStringType(), + type_factory.GetDynType()); + } + static absl::StatusOr> GetTypeType(TypeFactory& type_factory) { return type_factory.GetTypeType(); } - std::array types_; + static absl::StatusOr> GetValueType(TypeFactory& type_factory) { + return type_factory.GetDynType(); + } + + std::array types_; }; } // namespace diff --git a/base/type_provider_test.cc b/base/type_provider_test.cc index c5060d037..ec65eb308 100644 --- a/base/type_provider_test.cc +++ b/base/type_provider_test.cc @@ -14,6 +14,9 @@ #include "base/type_provider.h" +#include + +#include "base/internal/memory_manager_testing.h" #include "base/memory_manager.h" #include "base/type_factory.h" #include "internal/testing.h" @@ -25,22 +28,50 @@ using testing::Eq; using testing::Optional; using cel::internal::IsOkAndHolds; -TEST(BuiltinTypeProvider, ProvidesBoolWrapperType) { - TypeFactory type_factory(MemoryManager::Global()); +class BuiltinTypeProviderTest + : public testing::TestWithParam { + protected: + void SetUp() override { + if (GetParam() == base_internal::MemoryManagerTestMode::kArena) { + memory_manager_ = ArenaMemoryManager::Default(); + } + } + + void TearDown() override { + if (GetParam() == base_internal::MemoryManagerTestMode::kArena) { + memory_manager_.reset(); + } + } + + MemoryManager& memory_manager() const { + switch (GetParam()) { + case base_internal::MemoryManagerTestMode::kGlobal: + return MemoryManager::Global(); + case base_internal::MemoryManagerTestMode::kArena: + return *memory_manager_; + } + } + + private: + std::unique_ptr memory_manager_; +}; + +TEST_P(BuiltinTypeProviderTest, ProvidesBoolWrapperType) { + TypeFactory type_factory(memory_manager()); ASSERT_THAT(TypeProvider::Builtin().ProvideType(type_factory, "google.protobuf.BoolValue"), IsOkAndHolds(Optional(Eq(type_factory.GetBoolWrapperType())))); } -TEST(BuiltinTypeProvider, ProvidesBytesWrapperType) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(BuiltinTypeProviderTest, ProvidesBytesWrapperType) { + TypeFactory type_factory(memory_manager()); ASSERT_THAT(TypeProvider::Builtin().ProvideType(type_factory, "google.protobuf.BytesValue"), IsOkAndHolds(Optional(Eq(type_factory.GetBytesWrapperType())))); } -TEST(BuiltinTypeProvider, ProvidesDoubleWrapperType) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(BuiltinTypeProviderTest, ProvidesDoubleWrapperType) { + TypeFactory type_factory(memory_manager()); ASSERT_THAT(TypeProvider::Builtin().ProvideType(type_factory, "google.protobuf.FloatValue"), IsOkAndHolds(Optional(Eq(type_factory.GetDoubleWrapperType())))); @@ -49,8 +80,8 @@ TEST(BuiltinTypeProvider, ProvidesDoubleWrapperType) { IsOkAndHolds(Optional(Eq(type_factory.GetDoubleWrapperType())))); } -TEST(BuiltinTypeProvider, ProvidesIntWrapperType) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(BuiltinTypeProviderTest, ProvidesIntWrapperType) { + TypeFactory type_factory(memory_manager()); ASSERT_THAT(TypeProvider::Builtin().ProvideType(type_factory, "google.protobuf.Int32Value"), IsOkAndHolds(Optional(Eq(type_factory.GetIntWrapperType())))); @@ -59,15 +90,15 @@ TEST(BuiltinTypeProvider, ProvidesIntWrapperType) { IsOkAndHolds(Optional(Eq(type_factory.GetIntWrapperType())))); } -TEST(BuiltinTypeProvider, ProvidesStringWrapperType) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(BuiltinTypeProviderTest, ProvidesStringWrapperType) { + TypeFactory type_factory(memory_manager()); ASSERT_THAT(TypeProvider::Builtin().ProvideType( type_factory, "google.protobuf.StringValue"), IsOkAndHolds(Optional(Eq(type_factory.GetStringWrapperType())))); } -TEST(BuiltinTypeProvider, ProvidesUintWrapperType) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(BuiltinTypeProviderTest, ProvidesUintWrapperType) { + TypeFactory type_factory(memory_manager()); ASSERT_THAT(TypeProvider::Builtin().ProvideType( type_factory, "google.protobuf.UInt32Value"), IsOkAndHolds(Optional(Eq(type_factory.GetUintWrapperType())))); @@ -76,5 +107,44 @@ TEST(BuiltinTypeProvider, ProvidesUintWrapperType) { IsOkAndHolds(Optional(Eq(type_factory.GetUintWrapperType())))); } +TEST_P(BuiltinTypeProviderTest, ProvidesValueWrapperType) { + TypeFactory type_factory(memory_manager()); + ASSERT_THAT(TypeProvider::Builtin().ProvideType(type_factory, + "google.protobuf.Value"), + IsOkAndHolds(Optional(Eq(type_factory.GetDynType())))); +} + +TEST_P(BuiltinTypeProviderTest, ProvidesListWrapperType) { + TypeFactory type_factory(memory_manager()); + ASSERT_OK_AND_ASSIGN(auto list_type, + TypeProvider::Builtin().ProvideType( + type_factory, "google.protobuf.ListValue")); + ASSERT_TRUE(list_type.has_value()); + EXPECT_TRUE((*list_type)->Is()); + EXPECT_TRUE(list_type->As()->element()->Is()); +} + +TEST_P(BuiltinTypeProviderTest, ProvidesStructWrapperType) { + TypeFactory type_factory(memory_manager()); + ASSERT_OK_AND_ASSIGN(auto struct_type, + TypeProvider::Builtin().ProvideType( + type_factory, "google.protobuf.Struct")); + ASSERT_TRUE(struct_type.has_value()); + EXPECT_TRUE((*struct_type)->Is()); + EXPECT_TRUE(struct_type->As()->key()->Is()); + EXPECT_TRUE(struct_type->As()->value()->Is()); +} + +TEST_P(BuiltinTypeProviderTest, DoesNotProvide) { + TypeFactory type_factory(memory_manager()); + ASSERT_THAT( + TypeProvider::Builtin().ProvideType(type_factory, "google.protobuf.Api"), + IsOkAndHolds(Eq(absl::nullopt))); +} + +INSTANTIATE_TEST_SUITE_P(BuiltinTypeProviderTest, BuiltinTypeProviderTest, + base_internal::MemoryManagerTestModeAll(), + base_internal::MemoryManagerTestModeName); + } // namespace } // namespace cel From 621b90c77db2342952035d6bc302f90738cae5b2 Mon Sep 17 00:00:00 2001 From: tswadell Date: Tue, 25 Apr 2023 22:51:58 +0000 Subject: [PATCH 226/303] Minimize destruct overhead of empty AttributeTrail values PiperOrigin-RevId: 527097140 --- eval/eval/attribute_trail.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eval/eval/attribute_trail.h b/eval/eval/attribute_trail.h index b2922b8b5..4afbc10eb 100644 --- a/eval/eval/attribute_trail.h +++ b/eval/eval/attribute_trail.h @@ -28,7 +28,7 @@ namespace google::api::expr::runtime { // or supported. class AttributeTrail { public: - AttributeTrail() = default; + AttributeTrail() : attribute_(absl::nullopt) {} AttributeTrail(google::api::expr::v1alpha1::Expr root, cel::MemoryManager& manager); From 21505e7309ccb764f0ba04385513f2db687a5af2 Mon Sep 17 00:00:00 2001 From: jdtatum Date: Wed, 26 Apr 2023 18:04:00 +0000 Subject: [PATCH 227/303] Introduce interface for AST transformer. Move the reference resolve step to an implementation of this interface. PiperOrigin-RevId: 527317028 --- eval/compiler/BUILD | 15 ++++- eval/compiler/flat_expr_builder.cc | 22 ++----- eval/compiler/flat_expr_builder.h | 18 ++---- eval/compiler/flat_expr_builder_extensions.h | 61 +++++++++++++++++++ eval/compiler/flat_expr_builder_test.cc | 27 ++++++-- eval/compiler/qualified_reference_resolver.cc | 26 ++++++++ eval/compiler/qualified_reference_resolver.h | 12 ++++ eval/public/BUILD | 1 + .../portable_cel_expr_builder_factory.cc | 7 ++- 9 files changed, 151 insertions(+), 38 deletions(-) create mode 100644 eval/compiler/flat_expr_builder_extensions.h diff --git a/eval/compiler/BUILD b/eval/compiler/BUILD index 901a19e32..b81969900 100644 --- a/eval/compiler/BUILD +++ b/eval/compiler/BUILD @@ -6,6 +6,17 @@ licenses(["notice"]) exports_files(["LICENSE"]) +cc_library( + name = "flat_expr_builder_extensions", + hdrs = ["flat_expr_builder_extensions.h"], + deps = [ + ":resolver", + "//base:ast", + "//eval/eval:expression_build_warning", + "@com_google_absl//absl/status", + ], +) + cc_library( name = "flat_expr_builder", srcs = [ @@ -16,7 +27,7 @@ cc_library( ], deps = [ ":constant_folding", - ":qualified_reference_resolver", + ":flat_expr_builder_extensions", ":resolver", "//base:ast", "//base:ast_internal", @@ -72,6 +83,7 @@ cc_test( ], deps = [ ":flat_expr_builder", + ":qualified_reference_resolver", "//eval/eval:expression_build_warning", "//eval/public:activation", "//eval/public:builtin_func_registrar", @@ -191,6 +203,7 @@ cc_library( "qualified_reference_resolver.h", ], deps = [ + ":flat_expr_builder_extensions", ":resolver", "//base:ast", "//base:ast_internal", diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index 47ecfa4ce..8a75cca68 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -43,7 +43,7 @@ #include "base/internal/ast_impl.h" #include "base/values/string_value.h" #include "eval/compiler/constant_folding.h" -#include "eval/compiler/qualified_reference_resolver.h" +#include "eval/compiler/flat_expr_builder_extensions.h" #include "eval/compiler/resolver.h" #include "eval/eval/comprehension_step.h" #include "eval/eval/const_value_step.h" @@ -1250,6 +1250,9 @@ FlatExprBuilder::CreateExpressionImpl( GetTypeRegistry(), options_.enable_qualified_type_identifiers); absl::flat_hash_map> constant_idents; + + PlannerContext extension_context(resolver, warnings_builder); + auto& ast_impl = AstImpl::CastFromPublicAst(ast); const cel::ast::internal::Expr* effective_expr = &ast_impl.root_expr(); @@ -1258,21 +1261,8 @@ FlatExprBuilder::CreateExpressionImpl( absl::StrCat("Invalid expression container: '", container(), "'")); } - // transformed expression preserving expression IDs - bool rewrites_enabled = enable_qualified_identifier_rewrites_ || - !ast_impl.reference_map().empty(); - // TODO(issues/98): A type checker may perform these rewrites, but there - // currently isn't a signal to expose that in an expression. If that becomes - // available, we can skip the reference resolve step here if it's already - // done. - if (rewrites_enabled) { - absl::StatusOr rewritten = - ResolveReferences(resolver, warnings_builder, ast_impl); - if (!rewritten.ok()) { - return rewritten.status(); - } - // TODO(issues/99): we could setup a check step here that confirms all of - // references are defined before actually evaluating. + for (const std::unique_ptr& transform : ast_transforms_) { + CEL_RETURN_IF_ERROR(transform->UpdateAst(extension_context, ast_impl)); } cel::ast::internal::Expr const_fold_buffer; diff --git a/eval/compiler/flat_expr_builder.h b/eval/compiler/flat_expr_builder.h index 36104ddaa..ece6a25dc 100644 --- a/eval/compiler/flat_expr_builder.h +++ b/eval/compiler/flat_expr_builder.h @@ -18,12 +18,14 @@ #define THIRD_PARTY_CEL_CPP_EVAL_COMPILER_FLAT_EXPR_BUILDER_H_ #include +#include #include #include "google/api/expr/v1alpha1/checked.pb.h" #include "google/api/expr/v1alpha1/syntax.pb.h" #include "absl/status/statusor.h" #include "base/ast.h" +#include "eval/compiler/flat_expr_builder_extensions.h" #include "eval/public/cel_expression.h" #include "runtime/runtime_options.h" #include "google/protobuf/arena.h" @@ -59,15 +61,8 @@ class FlatExprBuilder : public CelExpressionBuilder { enable_comprehension_vulnerability_check_ = enabled; } - // If enable_qualified_identifier_rewrites is true, the evaluator will attempt - // to disambiguate namespace qualified identifiers. - // - // For functions, this will attempt to determine whether a function call is a - // receiver call or a namespace qualified function. - void set_enable_qualified_identifier_rewrites( - bool enable_qualified_identifier_rewrites) { - enable_qualified_identifier_rewrites_ = - enable_qualified_identifier_rewrites; + void AddAstTransform(std::unique_ptr transform) { + ast_transforms_.push_back(std::move(transform)); } void set_enable_regex_precompilation(bool enable) { @@ -101,11 +96,8 @@ class FlatExprBuilder : public CelExpressionBuilder { cel::ast::Ast& ast, std::vector* warnings) const; cel::RuntimeOptions options_; + std::vector> ast_transforms_; - int comprehension_max_iterations_ = 0; - bool enable_heterogeneous_equality_ = false; - - bool enable_qualified_identifier_rewrites_ = false; bool enable_regex_precompilation_ = false; bool enable_comprehension_vulnerability_check_ = false; bool constant_folding_ = false; diff --git a/eval/compiler/flat_expr_builder_extensions.h b/eval/compiler/flat_expr_builder_extensions.h new file mode 100644 index 000000000..c66141fa3 --- /dev/null +++ b/eval/compiler/flat_expr_builder_extensions.h @@ -0,0 +1,61 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// API definitions for planner extensions. +// +// These are provided to indirect build dependencies for optional features and +// require detailed understanding of how the flat expression builder works and +// its assumptions. +// +// These interfaces should not be implemented directly by CEL users. +#ifndef THIRD_PARTY_CEL_CPP_EVAL_COMPILER_FLAT_EXPR_BUILDER_EXTENSIONS_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_COMPILER_FLAT_EXPR_BUILDER_EXTENSIONS_H_ + +#include "absl/status/status.h" +#include "base/ast.h" +#include "eval/compiler/resolver.h" +#include "eval/eval/expression_build_warning.h" + +namespace google::api::expr::runtime { + +// Class representing FlatExpr internals exposed to extensions. +class PlannerContext { + public: + explicit PlannerContext(const Resolver& resolver, + BuilderWarnings& builder_warnings) + : resolver_(resolver), builder_warnings_(builder_warnings) {} + + const Resolver& resolver() const { return resolver_; } + BuilderWarnings& builder_warnings() { return builder_warnings_; } + + private: + const Resolver& resolver_; + BuilderWarnings& builder_warnings_; +}; + +// Interface for Ast Transforms. +// If any are present, the flat expr builder will apply the Ast Transforms in +// order on a copy of the relevant input expressions before planning the +// program. +class AstTransform { + public: + virtual ~AstTransform() = default; + + virtual absl::Status UpdateAst(PlannerContext& context, + cel::ast::internal::AstImpl& ast) const = 0; +}; + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_COMPILER_FLAT_EXPR_BUILDER_EXTENSIONS_H_ diff --git a/eval/compiler/flat_expr_builder_test.cc b/eval/compiler/flat_expr_builder_test.cc index d9d937886..06bce0513 100644 --- a/eval/compiler/flat_expr_builder_test.cc +++ b/eval/compiler/flat_expr_builder_test.cc @@ -37,6 +37,7 @@ #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "eval/compiler/qualified_reference_resolver.h" #include "eval/eval/expression_build_warning.h" #include "eval/public/activation.h" #include "eval/public/builtin_func_registrar.h" @@ -689,7 +690,8 @@ TEST(FlatExprBuilderTest, InvalidContainer) { TEST(FlatExprBuilderTest, ParsedNamespacedFunctionSupport) { ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("ext.XOr(a, b)")); FlatExprBuilder builder; - builder.set_enable_qualified_identifier_rewrites(true); + builder.AddAstTransform( + NewReferenceResolverExtension(ReferenceResolverOption::kAlways)); using FunctionAdapterT = FunctionAdapter; ASSERT_OK(FunctionAdapterT::CreateAndRegister( @@ -718,7 +720,8 @@ TEST(FlatExprBuilderTest, ParsedNamespacedFunctionSupport) { TEST(FlatExprBuilderTest, ParsedNamespacedFunctionSupportWithContainer) { ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("XOr(a, b)")); FlatExprBuilder builder; - builder.set_enable_qualified_identifier_rewrites(true); + builder.AddAstTransform( + NewReferenceResolverExtension(ReferenceResolverOption::kAlways)); builder.set_container("ext"); using FunctionAdapterT = FunctionAdapter; @@ -747,7 +750,8 @@ TEST(FlatExprBuilderTest, ParsedNamespacedFunctionSupportWithContainer) { TEST(FlatExprBuilderTest, ParsedNamespacedFunctionResolutionOrder) { ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("c.d.Get()")); FlatExprBuilder builder; - builder.set_enable_qualified_identifier_rewrites(true); + builder.AddAstTransform( + NewReferenceResolverExtension(ReferenceResolverOption::kAlways)); builder.set_container("a.b"); using FunctionAdapterT = FunctionAdapter; @@ -773,7 +777,8 @@ TEST(FlatExprBuilderTest, ParsedNamespacedFunctionResolutionOrderParentContainer) { ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("c.d.Get()")); FlatExprBuilder builder; - builder.set_enable_qualified_identifier_rewrites(true); + builder.AddAstTransform( + NewReferenceResolverExtension(ReferenceResolverOption::kAlways)); builder.set_container("a.b"); using FunctionAdapterT = FunctionAdapter; @@ -799,7 +804,8 @@ TEST(FlatExprBuilderTest, ParsedNamespacedFunctionResolutionOrderExplicitGlobal) { ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(".c.d.Get()")); FlatExprBuilder builder; - builder.set_enable_qualified_identifier_rewrites(true); + builder.AddAstTransform( + NewReferenceResolverExtension(ReferenceResolverOption::kAlways)); builder.set_container("a.b"); using FunctionAdapterT = FunctionAdapter; @@ -824,7 +830,8 @@ TEST(FlatExprBuilderTest, TEST(FlatExprBuilderTest, ParsedNamespacedFunctionResolutionOrderReceiverCall) { ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("e.Get()")); FlatExprBuilder builder; - builder.set_enable_qualified_identifier_rewrites(true); + builder.AddAstTransform( + NewReferenceResolverExtension(ReferenceResolverOption::kAlways)); builder.set_container("a.b"); using FunctionAdapterT = FunctionAdapter; @@ -958,6 +965,8 @@ TEST(FlatExprBuilderTest, CheckedExprWithReferenceMap) { &expr); FlatExprBuilder builder; + builder.AddAstTransform( + NewReferenceResolverExtension(ReferenceResolverOption::kCheckedOnly)); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr)); @@ -1025,6 +1034,8 @@ TEST(FlatExprBuilderTest, CheckedExprWithReferenceMapFunction) { &expr); FlatExprBuilder builder; + builder.AddAstTransform( + NewReferenceResolverExtension(ReferenceResolverOption::kCheckedOnly)); builder.set_container("com.foo"); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); ASSERT_OK((FunctionAdapter::CreateAndRegister( @@ -1091,6 +1102,8 @@ TEST(FlatExprBuilderTest, CheckedExprActivationMissesReferences) { &expr); FlatExprBuilder builder; + builder.AddAstTransform( + NewReferenceResolverExtension(ReferenceResolverOption::kCheckedOnly)); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr)); @@ -1154,6 +1167,8 @@ TEST(FlatExprBuilderTest, CheckedExprWithReferenceMapAndConstantFolding) { &expr); FlatExprBuilder builder; + builder.AddAstTransform( + NewReferenceResolverExtension(ReferenceResolverOption::kCheckedOnly)); google::protobuf::Arena arena; builder.set_constant_folding(true, &arena); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); diff --git a/eval/compiler/qualified_reference_resolver.cc b/eval/compiler/qualified_reference_resolver.cc index b11e70c22..8ff34f6b9 100644 --- a/eval/compiler/qualified_reference_resolver.cc +++ b/eval/compiler/qualified_reference_resolver.cc @@ -2,6 +2,7 @@ #include #include +#include #include #include "absl/container/flat_hash_map.h" @@ -13,6 +14,7 @@ #include "absl/types/optional.h" #include "base/ast.h" #include "base/internal/ast_impl.h" +#include "eval/compiler/flat_expr_builder_extensions.h" #include "eval/eval/const_value_step.h" #include "eval/eval/expression_build_warning.h" #include "eval/public/ast_rewrite_native.h" @@ -269,6 +271,25 @@ class ReferenceResolver : public cel::ast::internal::AstRewriterBase { absl::flat_hash_set rewritten_reference_; }; +class ReferenceResolverExtension : public AstTransform { + public: + explicit ReferenceResolverExtension(ReferenceResolverOption opt) + : opt_(opt) {} + absl::Status UpdateAst(PlannerContext& context, + cel::ast::internal::AstImpl& ast) const override { + if (opt_ == ReferenceResolverOption::kCheckedOnly && + ast.reference_map().empty()) { + return absl::OkStatus(); + } + return ResolveReferences(context.resolver(), context.builder_warnings(), + ast) + .status(); + } + + private: + ReferenceResolverOption opt_; +}; + } // namespace absl::StatusOr ResolveReferences(const Resolver& resolver, @@ -286,4 +307,9 @@ absl::StatusOr ResolveReferences(const Resolver& resolver, return was_rewritten; } +std::unique_ptr NewReferenceResolverExtension( + ReferenceResolverOption option) { + return std::make_unique(option); +} + } // namespace google::api::expr::runtime diff --git a/eval/compiler/qualified_reference_resolver.h b/eval/compiler/qualified_reference_resolver.h index 8851c2e97..e4205edc5 100644 --- a/eval/compiler/qualified_reference_resolver.h +++ b/eval/compiler/qualified_reference_resolver.h @@ -2,10 +2,12 @@ #define THIRD_PARTY_CEL_CPP_EVAL_COMPILER_QUALIFIED_REFERENCE_RESOLVER_H_ #include +#include #include "absl/status/statusor.h" #include "base/ast.h" #include "base/ast_internal.h" +#include "eval/compiler/flat_expr_builder_extensions.h" #include "eval/compiler/resolver.h" #include "eval/eval/expression_build_warning.h" @@ -24,6 +26,16 @@ absl::StatusOr ResolveReferences(const Resolver& resolver, BuilderWarnings& warnings, cel::ast::internal::AstImpl& ast); +enum class ReferenceResolverOption { + // Always attempt to resolve references based on runtime types and functions. + kAlways, + // Only attempt to resolve for checked expressions with reference metadata. + kCheckedOnly, +}; + +std::unique_ptr NewReferenceResolverExtension( + ReferenceResolverOption option); + } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_COMPILER_QUALIFIED_REFERENCE_RESOLVER_H_ diff --git a/eval/public/BUILD b/eval/public/BUILD index 38d6d2935..bfd939578 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -1210,6 +1210,7 @@ cc_library( ":cel_expression", ":cel_options", "//eval/compiler:flat_expr_builder", + "//eval/compiler:qualified_reference_resolver", "//eval/public/structs:legacy_type_provider", "//runtime:runtime_options", "@com_google_absl//absl/status", diff --git a/eval/public/portable_cel_expr_builder_factory.cc b/eval/public/portable_cel_expr_builder_factory.cc index 9dabe9cd8..9ea9ca1c7 100644 --- a/eval/public/portable_cel_expr_builder_factory.cc +++ b/eval/public/portable_cel_expr_builder_factory.cc @@ -22,6 +22,7 @@ #include "absl/status/status.h" #include "eval/compiler/flat_expr_builder.h" +#include "eval/compiler/qualified_reference_resolver.h" #include "eval/public/cel_options.h" #include "runtime/runtime_options.h" @@ -40,12 +41,14 @@ std::unique_ptr CreatePortableExprBuilder( builder->GetTypeRegistry()->RegisterTypeProvider(std::move(type_provider)); + builder->AddAstTransform(NewReferenceResolverExtension( + (options.enable_qualified_identifier_rewrites) + ? ReferenceResolverOption::kAlways + : ReferenceResolverOption::kCheckedOnly)); // TODO(issues/5): These need to be abstracted to avoid bringing in too // many build dependencies by default. builder->set_enable_comprehension_vulnerability_check( options.enable_comprehension_vulnerability_check); - builder->set_enable_qualified_identifier_rewrites( - options.enable_qualified_identifier_rewrites); builder->set_enable_regex_precompilation(options.enable_regex_precompilation); builder->set_constant_folding(options.constant_folding, options.constant_arena); From 3715030b60d0bbc9095a3b2b6ed6880d51efdf77 Mon Sep 17 00:00:00 2001 From: jcking Date: Thu, 27 Apr 2023 23:29:57 +0000 Subject: [PATCH 228/303] Standardize creation of values which reference other values PiperOrigin-RevId: 527716254 --- base/BUILD | 1 + base/handle.h | 39 ++ base/type.h | 6 + base/value.h | 6 + base/value_factory.h | 185 ++++-- base/values/bytes_value.h | 10 +- base/values/list_value.h | 4 +- base/values/map_value.h | 4 +- base/values/string_value.h | 10 +- base/values/struct_value.h | 4 +- extensions/protobuf/struct_value.cc | 905 +++++++++------------------- 11 files changed, 506 insertions(+), 668 deletions(-) diff --git a/base/BUILD b/base/BUILD index 96cc952c9..e2a419128 100644 --- a/base/BUILD +++ b/base/BUILD @@ -272,6 +272,7 @@ cc_library( "@com_google_absl//absl/hash", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/log:die_if_null", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", diff --git a/base/handle.h b/base/handle.h index 728005ec4..4a846b4e2 100644 --- a/base/handle.h +++ b/base/handle.h @@ -19,6 +19,7 @@ #include #include "absl/base/attributes.h" +#include "absl/base/macros.h" #include "absl/base/optimization.h" #include "absl/log/absl_check.h" #include "absl/utility/utility.h" @@ -255,6 +256,8 @@ class Handle final : private base_internal::HandlePolicy { explicit Handle(absl::in_place_t, Args&&... args) : impl_(std::forward(args)...) {} + T* Release() { return static_cast(impl_.release()); } + Impl impl_; }; @@ -309,6 +312,42 @@ struct HandleFactory { return Handle(absl::in_place, *base_internal::ManagedMemoryRelease(managed_memory)); } + + template + static std::enable_if_t, Handle> MakeFrom( + const F* data) { + if (Metadata::IsReferenceCounted(*data)) { + Metadata::Ref(*data); + } else { + ABSL_ASSERT(Metadata::IsArenaAllocated(*data)); + } + return Handle(absl::in_place, *const_cast(data)); + } + + // Requires that T is reference counted or arena allocated. Clears Handle + // without decrementing the reference count. Returns nullptr if T is arena + // allocated, T* otherwise. + template + static std::enable_if_t, F*> Release( + Handle& handle) { + T* data = handle.Release(); + if (Metadata::IsArenaAllocated(*data)) { + data = nullptr; + } else { + ABSL_ASSERT(Metadata::IsReferenceCounted(*data)); + } + return data; + } +}; + +template +class EnableHandleFromThis { + protected: + Handle handle_from_this() const { + static_assert(std::is_base_of_v, T>); + // It is guaranteed that we are either reference counted or arena allocated. + return HandleFactory::MakeFrom(reinterpret_cast(this)); + } }; } // namespace cel::base_internal diff --git a/base/type.h b/base/type.h index b243ddadc..1df85a36d 100644 --- a/base/type.h +++ b/base/type.h @@ -172,6 +172,12 @@ class TypeHandle final { void HashValue(absl::HashState state) const; + Type* release() { + Type* type = static_cast(data_.get_heap()); + data_.set_pointer(0); + return type; + } + private: static bool Equals(const Type& lhs, const Type& rhs, Kind kind); diff --git a/base/value.h b/base/value.h index 23591ad1d..3ebde96a2 100644 --- a/base/value.h +++ b/base/value.h @@ -166,6 +166,12 @@ class ValueHandle final { bool Equals(const ValueHandle& other) const; + Value* release() { + Value* value = static_cast(data_.get_heap()); + data_.set_pointer(0); + return value; + } + private: friend class ValueMetadata; diff --git a/base/value_factory.h b/base/value_factory.h index bb3363006..7685ad872 100644 --- a/base/value_factory.h +++ b/base/value_factory.h @@ -21,6 +21,7 @@ #include #include "absl/base/attributes.h" +#include "absl/log/die_if_null.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/cord.h" @@ -54,15 +55,44 @@ namespace cel { namespace base_internal { -struct ValueFactoryAccess; + +template +class ReferentValue final : public T { + public: + template + explicit ReferentValue(const cel::Value* referent, Args&&... args) + : T(std::forward(args)...), + referent_(ABSL_DIE_IF_NULL(referent)) // Crash OK + {} + + ~ReferentValue() override { ValueMetadata::Unref(*referent_); } + + private: + const cel::Value* const referent_; +}; + } // namespace base_internal class ValueFactory final { private: template - using EnableIfBaseOfT = + using EnableIfBaseOf = std::enable_if_t>, V>; + template + using EnableIfReferent = std::enable_if_t< + std::conjunction_v, + std::is_base_of>, + V>; + + template + using EnableIfBaseOfAndReferent = std::enable_if_t< + std::conjunction_v< + std::is_base_of>, + std::is_base_of>, + std::is_base_of>>, + V>; + public: explicit ValueFactory(TypeManager& type_manager ABSL_ATTRIBUTE_LIFETIME_BOUND) : type_manager_(type_manager) {} @@ -131,6 +161,20 @@ class ValueFactory final { absl::MakeCordFromExternal(value, std::forward(releaser))); } + template + EnableIfReferent>> + CreateReferentBytesValue(Handle reference, absl::string_view value) { + if (value.empty()) { + return GetEmptyBytesValue(); + } + auto* pointer = base_internal::HandleFactory::Release(reference); + if (pointer == nullptr) { + return base_internal::HandleFactory::Make< + base_internal::InlinedStringViewBytesValue>(value); + } + return CreateMemberBytesValue(value, pointer); + } + Handle GetStringValue() ABSL_ATTRIBUTE_LIFETIME_BOUND { return GetEmptyStringValue(); } @@ -169,6 +213,20 @@ class ValueFactory final { absl::MakeCordFromExternal(value, std::forward(releaser))); } + template + EnableIfReferent>> + CreateReferentStringValue(Handle reference, absl::string_view value) { + if (value.empty()) { + return GetEmptyStringValue(); + } + auto* pointer = base_internal::HandleFactory::Release(reference); + if (pointer == nullptr) { + return base_internal::HandleFactory::Make< + base_internal::InlinedStringViewStringValue>(value); + } + return CreateMemberStringValue(value, pointer); + } + absl::StatusOr> CreateDurationValue( absl::Duration value) ABSL_ATTRIBUTE_LIFETIME_BOUND; @@ -219,25 +277,51 @@ class ValueFactory final { } template - EnableIfBaseOfT>> CreateStructValue( - const Handle& struct_type, + EnableIfBaseOf>> CreateStructValue( + const Handle& type, Args&&... args) ABSL_ATTRIBUTE_LIFETIME_BOUND { return base_internal::HandleFactory::template Make< - std::remove_const_t>(memory_manager(), struct_type, + std::remove_const_t>(memory_manager(), type, std::forward(args)...); } template - EnableIfBaseOfT>> CreateStructValue( - Handle&& struct_type, - Args&&... args) ABSL_ATTRIBUTE_LIFETIME_BOUND { + EnableIfBaseOf>> CreateStructValue( + Handle&& type, Args&&... args) ABSL_ATTRIBUTE_LIFETIME_BOUND { return base_internal::HandleFactory::template Make< - std::remove_const_t>(memory_manager(), std::move(struct_type), + std::remove_const_t>(memory_manager(), std::move(type), std::forward(args)...); } + template + EnableIfBaseOfAndReferent>> + CreateReferentStructValue(Handle reference, const Handle& type, + Args&&... args) ABSL_ATTRIBUTE_LIFETIME_BOUND { + auto* pointer = base_internal::HandleFactory::Release(reference); + if (pointer == nullptr) { + return CreateStructValue(type, std::forward(args)...); + } + return base_internal::HandleFactory::template Make< + base_internal::ReferentValue>(memory_manager(), pointer, type, + std::forward(args)...); + } + + template + EnableIfBaseOfAndReferent>> + CreateReferentStructValue(Handle reference, Handle&& type, + Args&&... args) ABSL_ATTRIBUTE_LIFETIME_BOUND { + auto* pointer = base_internal::HandleFactory::Release(reference); + if (pointer == nullptr) { + return CreateStructValue(std::move(type), std::forward(args)...); + } + return base_internal::HandleFactory::template Make< + base_internal::ReferentValue>(memory_manager(), pointer, + std::move(type), + std::forward(args)...); + } + template - EnableIfBaseOfT>> CreateListValue( + EnableIfBaseOf>> CreateListValue( const Handle& type, Args&&... args) ABSL_ATTRIBUTE_LIFETIME_BOUND { return base_internal::HandleFactory::template Make< @@ -246,15 +330,42 @@ class ValueFactory final { } template - EnableIfBaseOfT>> CreateListValue( + EnableIfBaseOf>> CreateListValue( Handle&& type, Args&&... args) ABSL_ATTRIBUTE_LIFETIME_BOUND { return base_internal::HandleFactory::template Make< std::remove_const_t>(memory_manager(), std::move(type), std::forward(args)...); } + template + EnableIfBaseOfAndReferent>> + CreateReferentListValue(Handle reference, const Handle& type, + Args&&... args) ABSL_ATTRIBUTE_LIFETIME_BOUND { + auto* pointer = base_internal::HandleFactory::Release(reference); + if (pointer == nullptr) { + return CreateListValue(type, std::forward(args)...); + } + return base_internal::HandleFactory::template Make< + base_internal::ReferentValue>(memory_manager(), pointer, type, + std::forward(args)...); + } + + template + EnableIfBaseOfAndReferent>> + CreateReferentListValue(Handle reference, Handle&& type, + Args&&... args) ABSL_ATTRIBUTE_LIFETIME_BOUND { + auto* pointer = base_internal::HandleFactory::Release(reference); + if (pointer == nullptr) { + return CreateListValue(std::move(type), std::forward(args)...); + } + return base_internal::HandleFactory::template Make< + base_internal::ReferentValue>(memory_manager(), pointer, + std::move(type), + std::forward(args)...); + } + template - EnableIfBaseOfT>> CreateMapValue( + EnableIfBaseOf>> CreateMapValue( const Handle& type, Args&&... args) ABSL_ATTRIBUTE_LIFETIME_BOUND { return base_internal::HandleFactory::template Make< @@ -263,13 +374,40 @@ class ValueFactory final { } template - EnableIfBaseOfT>> CreateMapValue( + EnableIfBaseOf>> CreateMapValue( Handle&& type, Args&&... args) ABSL_ATTRIBUTE_LIFETIME_BOUND { return base_internal::HandleFactory::template Make< std::remove_const_t>(memory_manager(), std::move(type), std::forward(args)...); } + template + EnableIfBaseOfAndReferent>> + CreateReferentMapValue(Handle reference, const Handle& type, + Args&&... args) ABSL_ATTRIBUTE_LIFETIME_BOUND { + auto* pointer = base_internal::HandleFactory::Release(reference); + if (pointer == nullptr) { + return CreateMapValue(type, std::forward(args)...); + } + return base_internal::HandleFactory::template Make< + base_internal::ReferentValue>(memory_manager(), pointer, type, + std::forward(args)...); + } + + template + EnableIfBaseOfAndReferent>> + CreateReferentMapValue(Handle reference, Handle&& type, + Args&&... args) ABSL_ATTRIBUTE_LIFETIME_BOUND { + auto* pointer = base_internal::HandleFactory::Release(reference); + if (pointer == nullptr) { + return CreateMapValue(std::move(type), std::forward(args)...); + } + return base_internal::HandleFactory::template Make< + base_internal::ReferentValue>(memory_manager(), pointer, + std::move(type), + std::forward(args)...); + } + Handle CreateTypeValue(const Handle& value) ABSL_ATTRIBUTE_LIFETIME_BOUND; @@ -298,7 +436,6 @@ class ValueFactory final { private: friend class BytesValue; friend class StringValue; - friend struct base_internal::ValueFactoryAccess; Handle GetEmptyBytesValue() ABSL_ATTRIBUTE_LIFETIME_BOUND; @@ -313,9 +450,6 @@ class ValueFactory final { absl::StatusOr> CreateMemberBytesValue( absl::string_view value, const Value* owner) ABSL_ATTRIBUTE_LIFETIME_BOUND { - if (value.empty()) { - return GetEmptyBytesValue(); - } return base_internal::HandleFactory::template Make< base_internal::InlinedStringViewBytesValue>(value, owner); } @@ -323,9 +457,6 @@ class ValueFactory final { absl::StatusOr> CreateMemberStringValue( absl::string_view value, const Value* owner) ABSL_ATTRIBUTE_LIFETIME_BOUND { - if (value.empty()) { - return GetEmptyStringValue(); - } return base_internal::HandleFactory::template Make< base_internal::InlinedStringViewStringValue>(value, owner); } @@ -368,20 +499,6 @@ inline Handle ValueTraits::Wrap( return value_factory.CreateUncheckedTimestampValue(value); } -struct ValueFactoryAccess { - static absl::StatusOr> CreateMemberBytesValue( - ValueFactory& value_factory, absl::string_view value, - const Value* owner) { - return value_factory.CreateMemberBytesValue(value, owner); - } - - static absl::StatusOr> CreateMemberStringValue( - ValueFactory& value_factory, absl::string_view value, - const Value* owner) { - return value_factory.CreateMemberStringValue(value, owner); - } -}; - } // namespace base_internal } // namespace cel diff --git a/base/values/bytes_value.h b/base/values/bytes_value.h index 1b12b53c1..728e56ba1 100644 --- a/base/values/bytes_value.h +++ b/base/values/bytes_value.h @@ -156,20 +156,14 @@ class InlinedStringViewBytesValue final : public BytesValue, public InlineData { // by `owner`. `owner` may be nullptr, in which case `value` has no owner and // must live for the duration of the underlying `MemoryManager`. InlinedStringViewBytesValue(absl::string_view value, const Value* owner) - : InlinedStringViewBytesValue( - value, owner, - owner == nullptr || !Metadata::IsArenaAllocated(*owner)) {} + : InlinedStringViewBytesValue(value, owner, owner == nullptr) {} InlinedStringViewBytesValue(absl::string_view value, const Value* owner, bool trivial) : InlineData(kMetadata | (trivial ? kTrivial : uintptr_t{0}) | AsInlineVariant(InlinedBytesValueVariant::kStringView)), value_(value), - owner_(trivial ? nullptr : owner) { - if (owner_ != nullptr) { - Metadata::Ref(*owner_); - } - } + owner_(trivial ? nullptr : owner) {} // Only called when owner_ was, at some point, not nullptr. InlinedStringViewBytesValue(const InlinedStringViewBytesValue& other) diff --git a/base/values/list_value.h b/base/values/list_value.h index feee136d4..a14b5a0ce 100644 --- a/base/values/list_value.h +++ b/base/values/list_value.h @@ -158,7 +158,9 @@ class LegacyListValue final : public ListValue, public InlineData { uintptr_t impl_; }; -class AbstractListValue : public ListValue, public HeapData { +class AbstractListValue : public ListValue, + public HeapData, + public EnableHandleFromThis { public: static bool Is(const Value& value) { return value.kind() == kKind && diff --git a/base/values/map_value.h b/base/values/map_value.h index 823230f77..4c46be9df 100644 --- a/base/values/map_value.h +++ b/base/values/map_value.h @@ -185,7 +185,9 @@ class LegacyMapValue final : public MapValue, public InlineData { uintptr_t impl_; }; -class AbstractMapValue : public MapValue, public HeapData { +class AbstractMapValue : public MapValue, + public HeapData, + public EnableHandleFromThis { public: static bool Is(const Value& value) { return value.kind() == kKind && diff --git a/base/values/string_value.h b/base/values/string_value.h index d19fe934e..058dad924 100644 --- a/base/values/string_value.h +++ b/base/values/string_value.h @@ -172,20 +172,14 @@ class InlinedStringViewStringValue final : public StringValue, // by `owner`. `owner` may be nullptr, in which case `value` has no owner and // must live for the duration of the underlying `MemoryManager`. InlinedStringViewStringValue(absl::string_view value, const Value* owner) - : InlinedStringViewStringValue( - value, owner, - owner == nullptr || !Metadata::IsArenaAllocated(*owner)) {} + : InlinedStringViewStringValue(value, owner, owner == nullptr) {} InlinedStringViewStringValue(absl::string_view value, const Value* owner, bool trivial) : InlineData(kMetadata | (trivial ? kTrivial : uintptr_t{0}) | AsInlineVariant(InlinedStringValueVariant::kStringView)), value_(value), - owner_(trivial ? nullptr : owner) { - if (owner_ != nullptr) { - Metadata::Ref(*owner_); - } - } + owner_(trivial ? nullptr : owner) {} // Only called when owner_ was, at some point, not nullptr. InlinedStringViewStringValue(const InlinedStringViewStringValue& other) diff --git a/base/values/struct_value.h b/base/values/struct_value.h index 6f7d98ef7..7d0186147 100644 --- a/base/values/struct_value.h +++ b/base/values/struct_value.h @@ -227,7 +227,9 @@ class LegacyStructValue final : public StructValue, public InlineData { uintptr_t type_info_; }; -class AbstractStructValue : public StructValue, public HeapData { +class AbstractStructValue : public StructValue, + public HeapData, + public EnableHandleFromThis { public: static bool Is(const Value& value) { return value.kind() == kKind && diff --git a/extensions/protobuf/struct_value.cc b/extensions/protobuf/struct_value.cc index a92912274..e7cb7c02a 100644 --- a/extensions/protobuf/struct_value.cc +++ b/extensions/protobuf/struct_value.cc @@ -103,23 +103,6 @@ Handle CreateBytesValueFromView(absl::string_view value) { base_internal::InlinedStringViewBytesValue>(value); } -struct CreateStringValueFromStringWrapperVisitor final { - ValueFactory& value_factory; - const Value* owner; - - absl::StatusOr> operator()(absl::string_view value) const { - if (cel::base_internal::Metadata::IsReferenceCounted(*owner)) { - return base_internal::ValueFactoryAccess::CreateMemberStringValue( - value_factory, value, owner); - } - return CreateStringValueFromView(value); - } - - absl::StatusOr> operator()(absl::Cord value) const { - return value_factory.CreateStringValue(std::move(value)); - } -}; - struct DebugStringFromStringWrapperVisitor final { std::string operator()(absl::string_view value) const { return StringValue::DebugString(value); @@ -144,18 +127,13 @@ class HeapDynamicParsedProtoStructValue final class DynamicMemberParsedProtoStructValue : public ParsedProtoStructValue { public: - static absl::StatusOr> Create( - ValueFactory& value_factory, Handle type, const Value* parent, - const google::protobuf::Message* value); - - const google::protobuf::Message& value() const final { return *value_; } - - protected: DynamicMemberParsedProtoStructValue(Handle type, const google::protobuf::Message* value) : ParsedProtoStructValue(std::move(type)), value_(ABSL_DIE_IF_NULL(value)) {} // Crash OK + const google::protobuf::Message& value() const final { return *value_; } + absl::optional ValueReference( google::protobuf::Message& scratch, const google::protobuf::Descriptor& desc, internal::TypeInfo type) const final { @@ -169,49 +147,6 @@ class DynamicMemberParsedProtoStructValue : public ParsedProtoStructValue { const google::protobuf::Message* const value_; }; -class ArenaDynamicMemberParsedProtoStructValue final - : public DynamicMemberParsedProtoStructValue { - public: - ArenaDynamicMemberParsedProtoStructValue(Handle type, - const google::protobuf::Message* value) - : DynamicMemberParsedProtoStructValue(std::move(type), value) {} -}; - -class ReffedDynamicMemberParsedProtoStructValue final - : public DynamicMemberParsedProtoStructValue { - public: - ReffedDynamicMemberParsedProtoStructValue(Handle type, - const Value* parent, - const google::protobuf::Message* value) - : DynamicMemberParsedProtoStructValue(std::move(type), value), - parent_(parent) { - base_internal::Metadata::Ref(*parent_); - } - - ~ReffedDynamicMemberParsedProtoStructValue() override { - base_internal::ValueMetadata::Unref(*parent_); - } - - private: - const Value* const parent_; -}; - -absl::StatusOr> -DynamicMemberParsedProtoStructValue::Create(ValueFactory& value_factory, - Handle type, - const Value* parent, - const google::protobuf::Message* value) { - if (parent != nullptr && - base_internal::Metadata::IsReferenceCounted(*parent)) { - return value_factory - .CreateStructValue( - std::move(type), parent, value); - } - return value_factory - .CreateStructValue( - std::move(type), value); -} - } // namespace } // namespace protobuf_internal @@ -822,9 +757,10 @@ class ParsedProtoListValue if (&field != scratch.get()) { // Scratch was not used, we can avoid copying. scratch.reset(); - return protobuf_internal::DynamicMemberParsedProtoStructValue::Create( - context.value_factory(), type()->element().As(), this, - &field); + return context.value_factory() + .CreateReferentStructValue< + protobuf_internal::DynamicMemberParsedProtoStructValue>( + handle_from_this(), type()->element().As(), &field); } if (ProtoMemoryManager::Is(context.value_factory().memory_manager())) { auto* arena = ProtoMemoryManager::CastToProtoArena( @@ -1077,10 +1013,14 @@ class ParsedProtoListValue const auto& field = fields_.Get(static_cast(index), scratch.get()); CEL_ASSIGN_OR_RETURN(auto wrapped, protobuf_internal::UnwrapStringValueProto(field)); - return absl::visit( - protobuf_internal::CreateStringValueFromStringWrapperVisitor{ - context.value_factory(), this}, - std::move(wrapped)); + if (absl::holds_alternative(wrapped)) { + return context.value_factory().CreateReferentStringValue( + handle_from_this(), absl::get(wrapped)); + } else { + ABSL_ASSERT(absl::holds_alternative(wrapped)); + return context.value_factory().CreateStringValue( + absl::get(std::move(wrapped))); + } } private: @@ -1138,31 +1078,6 @@ class ParsedProtoListValue const google::protobuf::RepeatedFieldRef fields_; }; -template -class ArenaParsedProtoListValue final : public ParsedProtoListValue { - public: - using ParsedProtoListValue::ParsedProtoListValue; -}; - -template -class ReffedParsedProtoListValue final : public ParsedProtoListValue { - public: - ReffedParsedProtoListValue(Handle type, - google::protobuf::RepeatedFieldRef

fields, - const Value* owner) - : ParsedProtoListValue(std::move(type), std::move(fields)), - owner_(owner) { - cel::base_internal::ValueMetadata::Ref(*owner_); - } - - ~ReffedParsedProtoListValue() override { - cel::base_internal::ValueMetadata::Unref(*owner_); - } - - private: - const Value* owner_; -}; - void ProtoDebugStringEnum(std::string& out, const google::protobuf::EnumDescriptor& desc, int32_t value) { if (desc.full_name() == "google.protobuf.NullValue") { @@ -1290,34 +1205,6 @@ void ProtoDebugStringMap(std::string& out, const google::protobuf::Message& mess out.push_back('}'); } -absl::StatusOr> FromProtoMapKey(ValueFactory& value_factory, - const google::protobuf::MapKey& key, - const Value* owner) { - switch (key.type()) { - case google::protobuf::FieldDescriptor::CPPTYPE_INT64: - return value_factory.CreateIntValue(key.GetInt64Value()); - case google::protobuf::FieldDescriptor::CPPTYPE_INT32: - return value_factory.CreateIntValue(key.GetInt32Value()); - case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: - return value_factory.CreateUintValue(key.GetUInt64Value()); - case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: - return value_factory.CreateUintValue(key.GetUInt32Value()); - case google::protobuf::FieldDescriptor::CPPTYPE_STRING: - - if (cel::base_internal::Metadata::IsReferenceCounted(*owner)) { - return cel::base_internal::ValueFactoryAccess::CreateMemberStringValue( - value_factory, key.GetStringValue(), owner); - } - return protobuf_internal::CreateStringValueFromView(key.GetStringValue()); - case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: - return value_factory.CreateBoolValue(key.GetBoolValue()); - default: - // Unreachable because protobuf is extremely unlikely to introduce - // additional supported key types. - ABSL_UNREACHABLE(); - } -} - // Transform Value into MapKey. Requires that value is compatible with protocol // buffer map key. bool ToProtoMapKey(google::protobuf::MapKey& key, const Handle& value, @@ -1372,142 +1259,6 @@ bool ToProtoMapKey(google::protobuf::MapKey& key, const Handle& value, return true; } -absl::StatusOr> FromProtoMapValue( - ValueFactory& value_factory, const google::protobuf::MapValueConstRef& value, - const google::protobuf::FieldDescriptor& field, const Value* owner) { - const auto* value_desc = field.message_type()->map_value(); - switch (value_desc->cpp_type()) { - case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: - return value_factory.CreateBoolValue(value.GetBoolValue()); - case google::protobuf::FieldDescriptor::CPPTYPE_INT64: - return value_factory.CreateIntValue(value.GetInt64Value()); - case google::protobuf::FieldDescriptor::CPPTYPE_INT32: - return value_factory.CreateIntValue(value.GetInt32Value()); - case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: - return value_factory.CreateUintValue(value.GetUInt64Value()); - case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: - return value_factory.CreateUintValue(value.GetUInt32Value()); - case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: - return value_factory.CreateDoubleValue(value.GetFloatValue()); - case google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE: - return value_factory.CreateDoubleValue(value.GetDoubleValue()); - case google::protobuf::FieldDescriptor::CPPTYPE_STRING: { - if (value_desc->type() == google::protobuf::FieldDescriptor::TYPE_BYTES) { - if (cel::base_internal::Metadata::IsReferenceCounted(*owner)) { - return cel::base_internal::ValueFactoryAccess::CreateMemberBytesValue( - value_factory, value.GetStringValue(), owner); - } - return protobuf_internal::CreateBytesValueFromView( - value.GetStringValue()); - } else { - if (cel::base_internal::Metadata::IsReferenceCounted(*owner)) { - return cel::base_internal::ValueFactoryAccess:: - CreateMemberStringValue(value_factory, value.GetStringValue(), - owner); - } - return protobuf_internal::CreateStringValueFromView( - value.GetStringValue()); - } - } - case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: { - CEL_ASSIGN_OR_RETURN(auto type, - ProtoType::Resolve(value_factory.type_manager(), - *value_desc->enum_type())); - switch (type->kind()) { - case Kind::kNullType: - return value_factory.GetNullValue(); - case Kind::kEnum: - return value_factory.CreateEnumValue( - std::move(type).As(), value.GetEnumValue()); - default: - return absl::InternalError(absl::StrCat( - "Unexpected protocol buffer type implementation for \"", - value_desc->message_type()->full_name(), - "\": ", type->DebugString())); - } - } - case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: { - CEL_ASSIGN_OR_RETURN(auto type, - ProtoType::Resolve(value_factory.type_manager(), - *value_desc->message_type())); - switch (type->kind()) { - case Kind::kDuration: { - CEL_ASSIGN_OR_RETURN(auto duration, - protobuf_internal::AbslDurationFromDurationProto( - value.GetMessageValue())); - return value_factory.CreateUncheckedDurationValue(duration); - } - case Kind::kTimestamp: { - CEL_ASSIGN_OR_RETURN(auto time, - protobuf_internal::AbslTimeFromTimestampProto( - value.GetMessageValue())); - return value_factory.CreateUncheckedTimestampValue(time); - } - case Kind::kBool: { - // google.protobuf.BoolValue, mapped to CEL primitive bool type for - // map values. - CEL_ASSIGN_OR_RETURN( - auto wrapped, - protobuf_internal::UnwrapBoolValueProto(value.GetMessageValue())); - return value_factory.CreateBoolValue(wrapped); - } - case Kind::kBytes: { - // google.protobuf.BytesValue, mapped to CEL primitive bytes type for - // map values. - CEL_ASSIGN_OR_RETURN(auto wrapped, - protobuf_internal::UnwrapBytesValueProto( - value.GetMessageValue())); - return value_factory.CreateBytesValue(std::move(wrapped)); - } - case Kind::kDouble: { - // google.protobuf.{FloatValue,DoubleValue}, mapped to CEL primitive - // double type for map values. - CEL_ASSIGN_OR_RETURN(auto wrapped, - protobuf_internal::UnwrapDoubleValueProto( - value.GetMessageValue())); - return value_factory.CreateDoubleValue(wrapped); - } - case Kind::kInt: { - // google.protobuf.{Int32Value,Int64Value}, mapped to CEL primitive - // int type for map values. - CEL_ASSIGN_OR_RETURN( - auto wrapped, - protobuf_internal::UnwrapIntValueProto(value.GetMessageValue())); - return value_factory.CreateIntValue(wrapped); - } - case Kind::kString: { - // google.protobuf.StringValue, mapped to CEL primitive bytes type for - // map values. - CEL_ASSIGN_OR_RETURN(auto wrapped, - protobuf_internal::UnwrapStringValueProto( - value.GetMessageValue())); - return absl::visit( - protobuf_internal::CreateStringValueFromStringWrapperVisitor{ - value_factory, owner}, - std::move(wrapped)); - } - case Kind::kUint: { - // google.protobuf.{UInt32Value,UInt64Value}, mapped to CEL primitive - // uint type for map values. - CEL_ASSIGN_OR_RETURN( - auto wrapped, - protobuf_internal::UnwrapUIntValueProto(value.GetMessageValue())); - return value_factory.CreateUintValue(wrapped); - } - case Kind::kStruct: - return protobuf_internal::DynamicMemberParsedProtoStructValue::Create( - value_factory, std::move(type).As(), owner, - &value.GetMessageValue()); - default: - return absl::InternalError(absl::StrCat( - "Unexpected protocol buffer type implementation for \"", - value_desc->message_type()->full_name(), - "\": ", type->DebugString())); - } - } - } -} - class ParsedProtoMapValueKeysList : public CEL_LIST_VALUE_CLASS { public: ParsedProtoMapValueKeysList( @@ -1535,7 +1286,26 @@ class ParsedProtoMapValueKeysList : public CEL_LIST_VALUE_CLASS { absl::StatusOr> Get(const GetContext& context, size_t index) const final { - return FromProtoMapKey(context.value_factory(), keys_[index], this); + const auto& key = keys_[index]; + switch (key.type()) { + case google::protobuf::FieldDescriptor::CPPTYPE_INT64: + return context.value_factory().CreateIntValue(key.GetInt64Value()); + case google::protobuf::FieldDescriptor::CPPTYPE_INT32: + return context.value_factory().CreateIntValue(key.GetInt32Value()); + case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: + return context.value_factory().CreateUintValue(key.GetUInt64Value()); + case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: + return context.value_factory().CreateUintValue(key.GetUInt32Value()); + case google::protobuf::FieldDescriptor::CPPTYPE_STRING: + return context.value_factory().CreateReferentStringValue( + handle_from_this(), key.GetStringValue()); + case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: + return context.value_factory().CreateBoolValue(key.GetBoolValue()); + default: + // Unreachable because protobuf is extremely unlikely to introduce + // additional supported key types. + ABSL_UNREACHABLE(); + } } private: @@ -1546,32 +1316,6 @@ class ParsedProtoMapValueKeysList : public CEL_LIST_VALUE_CLASS { const std::vector> keys_; }; -class ArenaParsedProtoMapValueKeysList final - : public ParsedProtoMapValueKeysList { - public: - using ParsedProtoMapValueKeysList::ParsedProtoMapValueKeysList; -}; - -class ReffedParsedProtoMapValueKeysList final - : public ParsedProtoMapValueKeysList { - public: - ReffedParsedProtoMapValueKeysList( - Handle type, - std::vector> keys, - const Value* owner) - : ParsedProtoMapValueKeysList(std::move(type), std::move(keys)), - owner_(owner) { - cel::base_internal::ValueMetadata::Ref(*owner_); - } - - ~ReffedParsedProtoMapValueKeysList() override { - cel::base_internal::ValueMetadata::Unref(*owner_); - } - - private: - const Value* const owner_; -}; - class ParsedProtoMapValue : public CEL_MAP_VALUE_CLASS { public: ParsedProtoMapValue(Handle type, const google::protobuf::Message& message, @@ -1607,8 +1351,145 @@ class ParsedProtoMapValue : public CEL_MAP_VALUE_CLASS { proto_key, &proto_value)) { return absl::nullopt; } - return FromProtoMapValue(context.value_factory(), proto_value, field_, - this); + const auto* value_desc = field_.message_type()->map_value(); + switch (value_desc->cpp_type()) { + case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: + return context.value_factory().CreateBoolValue( + proto_value.GetBoolValue()); + case google::protobuf::FieldDescriptor::CPPTYPE_INT64: + return context.value_factory().CreateIntValue( + proto_value.GetInt64Value()); + case google::protobuf::FieldDescriptor::CPPTYPE_INT32: + return context.value_factory().CreateIntValue( + proto_value.GetInt32Value()); + case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: + return context.value_factory().CreateUintValue( + proto_value.GetUInt64Value()); + case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: + return context.value_factory().CreateUintValue( + proto_value.GetUInt32Value()); + case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: + return context.value_factory().CreateDoubleValue( + proto_value.GetFloatValue()); + case google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE: + return context.value_factory().CreateDoubleValue( + proto_value.GetDoubleValue()); + case google::protobuf::FieldDescriptor::CPPTYPE_STRING: { + if (value_desc->type() == google::protobuf::FieldDescriptor::TYPE_BYTES) { + return context.value_factory().CreateReferentBytesValue( + handle_from_this(), proto_value.GetStringValue()); + } else { + return context.value_factory().CreateReferentStringValue( + handle_from_this(), proto_value.GetStringValue()); + } + } + case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: { + CEL_ASSIGN_OR_RETURN( + auto type, + ProtoType::Resolve(context.value_factory().type_manager(), + *value_desc->enum_type())); + switch (type->kind()) { + case Kind::kNullType: + return context.value_factory().GetNullValue(); + case Kind::kEnum: + return context.value_factory().CreateEnumValue( + std::move(type).As(), + proto_value.GetEnumValue()); + default: + return absl::InternalError(absl::StrCat( + "Unexpected protocol buffer type implementation for \"", + value_desc->message_type()->full_name(), + "\": ", type->DebugString())); + } + } + case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: { + CEL_ASSIGN_OR_RETURN( + auto type, + ProtoType::Resolve(context.value_factory().type_manager(), + *value_desc->message_type())); + switch (type->kind()) { + case Kind::kDuration: { + CEL_ASSIGN_OR_RETURN( + auto duration, protobuf_internal::AbslDurationFromDurationProto( + proto_value.GetMessageValue())); + return context.value_factory().CreateUncheckedDurationValue( + duration); + } + case Kind::kTimestamp: { + CEL_ASSIGN_OR_RETURN(auto time, + protobuf_internal::AbslTimeFromTimestampProto( + proto_value.GetMessageValue())); + return context.value_factory().CreateUncheckedTimestampValue(time); + } + case Kind::kBool: { + // google.protobuf.BoolValue, mapped to CEL primitive bool type for + // map values. + CEL_ASSIGN_OR_RETURN(auto wrapped, + protobuf_internal::UnwrapBoolValueProto( + proto_value.GetMessageValue())); + return context.value_factory().CreateBoolValue(wrapped); + } + case Kind::kBytes: { + // google.protobuf.BytesValue, mapped to CEL primitive bytes type + // for map values. + CEL_ASSIGN_OR_RETURN(auto wrapped, + protobuf_internal::UnwrapBytesValueProto( + proto_value.GetMessageValue())); + return context.value_factory().CreateBytesValue(std::move(wrapped)); + } + case Kind::kDouble: { + // google.protobuf.{FloatValue,DoubleValue}, mapped to CEL primitive + // double type for map values. + CEL_ASSIGN_OR_RETURN(auto wrapped, + protobuf_internal::UnwrapDoubleValueProto( + proto_value.GetMessageValue())); + return context.value_factory().CreateDoubleValue(wrapped); + } + case Kind::kInt: { + // google.protobuf.{Int32Value,Int64Value}, mapped to CEL primitive + // int type for map values. + CEL_ASSIGN_OR_RETURN(auto wrapped, + protobuf_internal::UnwrapIntValueProto( + proto_value.GetMessageValue())); + return context.value_factory().CreateIntValue(wrapped); + } + case Kind::kString: { + // google.protobuf.StringValue, mapped to CEL primitive bytes type + // for map values. + CEL_ASSIGN_OR_RETURN(auto wrapped, + protobuf_internal::UnwrapStringValueProto( + proto_value.GetMessageValue())); + if (absl::holds_alternative(wrapped)) { + return context.value_factory().CreateReferentStringValue( + handle_from_this(), absl::get(wrapped)); + } else { + ABSL_ASSERT(absl::holds_alternative(wrapped)); + return context.value_factory().CreateStringValue( + absl::get(std::move(wrapped))); + } + } + case Kind::kUint: { + // google.protobuf.{UInt32Value,UInt64Value}, mapped to CEL + // primitive uint type for map values. + CEL_ASSIGN_OR_RETURN(auto wrapped, + protobuf_internal::UnwrapUIntValueProto( + proto_value.GetMessageValue())); + return context.value_factory().CreateUintValue(wrapped); + } + case Kind::kStruct: + return context.value_factory() + .CreateReferentStructValue< + protobuf_internal::DynamicMemberParsedProtoStructValue>( + handle_from_this(), std::move(type).As(), + &proto_value.GetMessageValue()); + default: + return absl::InternalError(absl::StrCat( + "Unexpected protocol buffer type implementation for \"", + value_desc->message_type()->full_name(), + "\": ", type->DebugString())); + } + } + } } absl::StatusOr Has(const HasContext& context, @@ -1640,14 +1521,9 @@ class ParsedProtoMapValue : public CEL_MAP_VALUE_CLASS { for (; begin != end; ++begin) { keys.push_back(begin.GetKey()); } - if (cel::base_internal::Metadata::IsReferenceCounted(*this)) { - return context.value_factory() - .CreateListValue( - std::move(list_type), std::move(keys), this); - } return context.value_factory() - .CreateListValue(std::move(list_type), - std::move(keys)); + .CreateReferentListValue( + handle_from_this(), std::move(list_type), std::move(keys)); } private: @@ -1663,29 +1539,6 @@ class ParsedProtoMapValue : public CEL_MAP_VALUE_CLASS { const google::protobuf::FieldDescriptor& field_; }; -class ArenaParsedProtoMapValue final : public ParsedProtoMapValue { - public: - using ParsedProtoMapValue::ParsedProtoMapValue; -}; - -class ReffedParsedProtoMapValue final : public ParsedProtoMapValue { - public: - ReffedParsedProtoMapValue(Handle type, - const google::protobuf::Message& message, - const google::protobuf::FieldDescriptor& field, - const Value* owner) - : ParsedProtoMapValue(std::move(type), message, field), owner_(owner) { - cel::base_internal::ValueMetadata::Ref(*owner_); - } - - ~ReffedParsedProtoMapValue() override { - cel::base_internal::ValueMetadata::Unref(*owner_); - } - - private: - const Value* owner_; -}; - void ProtoDebugStringSingular(std::string& out, const google::protobuf::Message& message, const google::protobuf::Reflection* reflect, const google::protobuf::FieldDescriptor* field_desc) { @@ -2175,13 +2028,8 @@ absl::StatusOr> ParsedProtoStructValue::GetMapField( const GetFieldContext& context, const StructType::Field& field, const google::protobuf::Reflection& reflect, const google::protobuf::FieldDescriptor& field_desc) const { - if (cel::base_internal::Metadata::IsReferenceCounted(*this)) { - return context.value_factory().CreateMapValue( - field.type.As(), value(), field_desc, this); - } else { - return context.value_factory().CreateMapValue( - field.type.As(), value(), field_desc); - } + return context.value_factory().CreateReferentMapValue( + handle_from_this(), field.type.As(), value(), field_desc); } absl::StatusOr> ParsedProtoStructValue::GetRepeatedField( @@ -2190,308 +2038,146 @@ absl::StatusOr> ParsedProtoStructValue::GetRepeatedField( const google::protobuf::FieldDescriptor& field_desc) const { switch (field_desc.type()) { case google::protobuf::FieldDescriptor::TYPE_DOUBLE: - if (cel::base_internal::Metadata::IsReferenceCounted(*this)) { - return context.value_factory() - .CreateListValue>( - field.type.As(), - reflect.GetRepeatedFieldRef(value(), &field_desc), - this); - } else { - return context.value_factory() - .CreateListValue>( - field.type.As(), - reflect.GetRepeatedFieldRef(value(), &field_desc)); - } + return context.value_factory() + .CreateReferentListValue>( + handle_from_this(), field.type.As(), + reflect.GetRepeatedFieldRef(value(), &field_desc)); case google::protobuf::FieldDescriptor::TYPE_FLOAT: - if (cel::base_internal::Metadata::IsReferenceCounted(*this)) { - return context.value_factory() - .CreateListValue>( - field.type.As(), - reflect.GetRepeatedFieldRef(value(), &field_desc), this); - } else { - return context.value_factory() - .CreateListValue>( - field.type.As(), - reflect.GetRepeatedFieldRef(value(), &field_desc)); - } + return context.value_factory() + .CreateReferentListValue>( + handle_from_this(), field.type.As(), + reflect.GetRepeatedFieldRef(value(), &field_desc)); case google::protobuf::FieldDescriptor::TYPE_INT64: ABSL_FALLTHROUGH_INTENDED; case google::protobuf::FieldDescriptor::TYPE_SFIXED64: ABSL_FALLTHROUGH_INTENDED; case google::protobuf::FieldDescriptor::TYPE_SINT64: - if (cel::base_internal::Metadata::IsReferenceCounted(*this)) { - return context.value_factory() - .CreateListValue>( - field.type.As(), - reflect.GetRepeatedFieldRef(value(), &field_desc), - this); - } else { - return context.value_factory() - .CreateListValue>( - field.type.As(), - reflect.GetRepeatedFieldRef(value(), &field_desc)); - } + return context.value_factory() + .CreateReferentListValue>( + handle_from_this(), field.type.As(), + reflect.GetRepeatedFieldRef(value(), &field_desc)); case google::protobuf::FieldDescriptor::TYPE_INT32: ABSL_FALLTHROUGH_INTENDED; case google::protobuf::FieldDescriptor::TYPE_SFIXED32: ABSL_FALLTHROUGH_INTENDED; case google::protobuf::FieldDescriptor::TYPE_SINT32: - if (cel::base_internal::Metadata::IsReferenceCounted(*this)) { - return context.value_factory() - .CreateListValue>( - field.type.As(), - reflect.GetRepeatedFieldRef(value(), &field_desc), - this); - } else { - return context.value_factory() - .CreateListValue>( - field.type.As(), - reflect.GetRepeatedFieldRef(value(), &field_desc)); - } + return context.value_factory() + .CreateReferentListValue>( + handle_from_this(), field.type.As(), + reflect.GetRepeatedFieldRef(value(), &field_desc)); case google::protobuf::FieldDescriptor::TYPE_UINT64: ABSL_FALLTHROUGH_INTENDED; case google::protobuf::FieldDescriptor::TYPE_FIXED64: - if (cel::base_internal::Metadata::IsReferenceCounted(*this)) { - return context.value_factory() - .CreateListValue>( - field.type.As(), - reflect.GetRepeatedFieldRef(value(), &field_desc), - this); - } else { - return context.value_factory() - .CreateListValue>( - field.type.As(), - reflect.GetRepeatedFieldRef(value(), &field_desc)); - } + return context.value_factory() + .CreateReferentListValue>( + handle_from_this(), field.type.As(), + reflect.GetRepeatedFieldRef(value(), &field_desc)); case google::protobuf::FieldDescriptor::TYPE_FIXED32: ABSL_FALLTHROUGH_INTENDED; case google::protobuf::FieldDescriptor::TYPE_UINT32: - if (cel::base_internal::Metadata::IsReferenceCounted(*this)) { - return context.value_factory() - .CreateListValue>( - field.type.As(), - reflect.GetRepeatedFieldRef(value(), &field_desc), - this); - } else { - return context.value_factory() - .CreateListValue>( - field.type.As(), - reflect.GetRepeatedFieldRef(value(), &field_desc)); - } + return context.value_factory() + .CreateReferentListValue>( + handle_from_this(), field.type.As(), + reflect.GetRepeatedFieldRef(value(), &field_desc)); case google::protobuf::FieldDescriptor::TYPE_BOOL: - if (cel::base_internal::Metadata::IsReferenceCounted(*this)) { - return context.value_factory() - .CreateListValue>( - field.type.As(), - reflect.GetRepeatedFieldRef(value(), &field_desc), this); - } else { - return context.value_factory() - .CreateListValue>( - field.type.As(), - reflect.GetRepeatedFieldRef(value(), &field_desc)); - } + return context.value_factory() + .CreateReferentListValue>( + handle_from_this(), field.type.As(), + reflect.GetRepeatedFieldRef(value(), &field_desc)); case google::protobuf::FieldDescriptor::TYPE_STRING: - if (cel::base_internal::Metadata::IsReferenceCounted(*this)) { - return context.value_factory() - .CreateListValue< - ReffedParsedProtoListValue>( - field.type.As(), - reflect.GetRepeatedFieldRef(value(), &field_desc), - this); - } else { - return context.value_factory() - .CreateListValue< - ArenaParsedProtoListValue>( - field.type.As(), - reflect.GetRepeatedFieldRef(value(), &field_desc)); - } + return context.value_factory() + .CreateReferentListValue< + ParsedProtoListValue>( + handle_from_this(), field.type.As(), + reflect.GetRepeatedFieldRef(value(), &field_desc)); case google::protobuf::FieldDescriptor::TYPE_GROUP: ABSL_FALLTHROUGH_INTENDED; case google::protobuf::FieldDescriptor::TYPE_MESSAGE: switch (field.type.As()->element()->kind()) { case Kind::kDuration: - if (cel::base_internal::Metadata::IsReferenceCounted(*this)) { - return context.value_factory() - .CreateListValue< - ReffedParsedProtoListValue>( - field.type.As(), - reflect.GetRepeatedFieldRef(value(), - &field_desc), - this); - } else { - return context.value_factory() - .CreateListValue< - ArenaParsedProtoListValue>( - field.type.As(), - reflect.GetRepeatedFieldRef(value(), - &field_desc)); - } + return context.value_factory() + .CreateReferentListValue< + ParsedProtoListValue>( + handle_from_this(), field.type.As(), + reflect.GetRepeatedFieldRef(value(), + &field_desc)); case Kind::kTimestamp: - if (cel::base_internal::Metadata::IsReferenceCounted(*this)) { - return context.value_factory() - .CreateListValue>( - field.type.As(), - reflect.GetRepeatedFieldRef(value(), - &field_desc), - this); - } else { - return context.value_factory() - .CreateListValue< - ArenaParsedProtoListValue>( - field.type.As(), - reflect.GetRepeatedFieldRef(value(), - &field_desc)); - } + return context.value_factory() + .CreateReferentListValue< + ParsedProtoListValue>( + handle_from_this(), field.type.As(), + reflect.GetRepeatedFieldRef(value(), + &field_desc)); case Kind::kBool: // google.protobuf.BoolValue, mapped to CEL primitive bool type for // list elements. - if (cel::base_internal::Metadata::IsReferenceCounted(*this)) { - return context.value_factory() - .CreateListValue< - ReffedParsedProtoListValue>( - field.type.As(), - reflect.GetRepeatedFieldRef(value(), - &field_desc), - this); - } else { - return context.value_factory() - .CreateListValue< - ArenaParsedProtoListValue>( - field.type.As(), - reflect.GetRepeatedFieldRef(value(), - &field_desc)); - } + return context.value_factory() + .CreateReferentListValue< + ParsedProtoListValue>( + handle_from_this(), field.type.As(), + reflect.GetRepeatedFieldRef(value(), + &field_desc)); case Kind::kBytes: // google.protobuf.BytesValue, mapped to CEL primitive bytes type for // list elements. - if (cel::base_internal::Metadata::IsReferenceCounted(*this)) { - return context.value_factory() - .CreateListValue< - ReffedParsedProtoListValue>( - field.type.As(), - reflect.GetRepeatedFieldRef(value(), - &field_desc), - this); - } else { - return context.value_factory() - .CreateListValue< - ArenaParsedProtoListValue>( - field.type.As(), - reflect.GetRepeatedFieldRef(value(), - &field_desc)); - } + return context.value_factory() + .CreateReferentListValue< + ParsedProtoListValue>( + handle_from_this(), field.type.As(), + reflect.GetRepeatedFieldRef(value(), + &field_desc)); case Kind::kDouble: // google.protobuf.{FloatValue,DoubleValue}, mapped to CEL primitive // double type for list elements. - if (cel::base_internal::Metadata::IsReferenceCounted(*this)) { - return context.value_factory() - .CreateListValue< - ReffedParsedProtoListValue>( - field.type.As(), - reflect.GetRepeatedFieldRef(value(), - &field_desc), - this); - } else { - return context.value_factory() - .CreateListValue< - ArenaParsedProtoListValue>( - field.type.As(), - reflect.GetRepeatedFieldRef(value(), - &field_desc)); - } + return context.value_factory() + .CreateReferentListValue< + ParsedProtoListValue>( + handle_from_this(), field.type.As(), + reflect.GetRepeatedFieldRef(value(), + &field_desc)); case Kind::kInt: // google.protobuf.{Int32Value,Int64Value}, mapped to CEL primitive // int type for list elements. - if (cel::base_internal::Metadata::IsReferenceCounted(*this)) { - return context.value_factory() - .CreateListValue< - ReffedParsedProtoListValue>( - field.type.As(), - reflect.GetRepeatedFieldRef(value(), - &field_desc), - this); - } else { - return context.value_factory() - .CreateListValue< - ArenaParsedProtoListValue>( - field.type.As(), - reflect.GetRepeatedFieldRef(value(), - &field_desc)); - } + return context.value_factory() + .CreateReferentListValue< + ParsedProtoListValue>( + handle_from_this(), field.type.As(), + reflect.GetRepeatedFieldRef(value(), + &field_desc)); case Kind::kString: // google.protobuf.StringValue, mapped to CEL primitive bytes type for // list elements. - if (cel::base_internal::Metadata::IsReferenceCounted(*this)) { - return context.value_factory() - .CreateListValue< - ReffedParsedProtoListValue>( - field.type.As(), - reflect.GetRepeatedFieldRef(value(), - &field_desc), - this); - } else { - return context.value_factory() - .CreateListValue< - ArenaParsedProtoListValue>( - field.type.As(), - reflect.GetRepeatedFieldRef(value(), - &field_desc)); - } + return context.value_factory() + .CreateReferentListValue< + ParsedProtoListValue>( + handle_from_this(), field.type.As(), + reflect.GetRepeatedFieldRef(value(), + &field_desc)); case Kind::kUint: // google.protobuf.{UInt32Value,UInt64Value}, mapped to CEL primitive // uint type for list elements. - if (cel::base_internal::Metadata::IsReferenceCounted(*this)) { - return context.value_factory() - .CreateListValue< - ReffedParsedProtoListValue>( - field.type.As(), - reflect.GetRepeatedFieldRef(value(), - &field_desc), - this); - } else { - return context.value_factory() - .CreateListValue< - ArenaParsedProtoListValue>( - field.type.As(), - reflect.GetRepeatedFieldRef(value(), - &field_desc)); - } + return context.value_factory() + .CreateReferentListValue< + ParsedProtoListValue>( + handle_from_this(), field.type.As(), + reflect.GetRepeatedFieldRef(value(), + &field_desc)); case Kind::kStruct: - if (cel::base_internal::Metadata::IsReferenceCounted(*this)) { - return context.value_factory() - .CreateListValue>( - field.type.As(), - reflect.GetRepeatedFieldRef(value(), - &field_desc), - this); - } else { - return context.value_factory() - .CreateListValue>( - field.type.As(), - reflect.GetRepeatedFieldRef(value(), - &field_desc)); - } + return context.value_factory() + .CreateReferentListValue< + ParsedProtoListValue>( + handle_from_this(), field.type.As(), + reflect.GetRepeatedFieldRef(value(), + &field_desc)); default: ABSL_UNREACHABLE(); } case google::protobuf::FieldDescriptor::TYPE_BYTES: - if (cel::base_internal::Metadata::IsReferenceCounted(*this)) { - return context.value_factory() - .CreateListValue< - ReffedParsedProtoListValue>( - field.type.As(), - reflect.GetRepeatedFieldRef(value(), &field_desc), - this); - } else { - return context.value_factory() - .CreateListValue< - ArenaParsedProtoListValue>( - field.type.As(), - reflect.GetRepeatedFieldRef(value(), &field_desc)); - } + return context.value_factory() + .CreateReferentListValue< + ParsedProtoListValue>( + handle_from_this(), field.type.As(), + reflect.GetRepeatedFieldRef(value(), &field_desc)); case google::protobuf::FieldDescriptor::TYPE_ENUM: switch (field.type.As()->element()->kind()) { case Kind::kNullType: @@ -2501,19 +2187,11 @@ absl::StatusOr> ParsedProtoStructValue::GetRepeatedField( reflect.GetRepeatedFieldRef(value(), &field_desc) .size()); case Kind::kEnum: - if (cel::base_internal::Metadata::IsReferenceCounted(*this)) { - return context.value_factory() - .CreateListValue< - ReffedParsedProtoListValue>( - field.type.As(), - reflect.GetRepeatedFieldRef(value(), &field_desc), - this); - } else { - return context.value_factory() - .CreateListValue>( - field.type.As(), - reflect.GetRepeatedFieldRef(value(), &field_desc)); - } + return context.value_factory() + .CreateReferentListValue< + ParsedProtoListValue>( + handle_from_this(), field.type.As(), + reflect.GetRepeatedFieldRef(value(), &field_desc)); default: ABSL_UNREACHABLE(); } @@ -2563,13 +2241,9 @@ absl::StatusOr> ParsedProtoStructValue::GetSingularField( !field_desc.is_extension()) { return context.value_factory().CreateStringValue( reflect.GetCord(value(), &field_desc)); - } else if (cel::base_internal::Metadata::IsReferenceCounted(*this)) { - return base_internal::ValueFactoryAccess::CreateMemberStringValue( - context.value_factory(), - reflect.GetStringView(value(), &field_desc), this); } else { - return CreateStringValueFromView( - reflect.GetStringView(value(), &field_desc)); + return context.value_factory().CreateReferentStringValue( + handle_from_this(), reflect.GetStringView(value(), &field_desc)); } case google::protobuf::FieldDescriptor::TYPE_GROUP: ABSL_FALLTHROUGH_INTENDED; @@ -2630,10 +2304,14 @@ absl::StatusOr> ParsedProtoStructValue::GetSingularField( auto wrapped, protobuf_internal::UnwrapStringValueProto(reflect.GetMessage( value(), &field_desc, type()->factory_))); - return absl::visit( - CreateStringValueFromStringWrapperVisitor{ - context.value_factory(), this}, - std::move(wrapped)); + if (absl::holds_alternative(wrapped)) { + return context.value_factory().CreateReferentStringValue( + handle_from_this(), absl::get(wrapped)); + } else { + ABSL_ASSERT(absl::holds_alternative(wrapped)); + return context.value_factory().CreateStringValue( + absl::get(std::move(wrapped))); + } } case Kind::kUint: { CEL_ASSIGN_OR_RETURN( @@ -2648,9 +2326,10 @@ absl::StatusOr> ParsedProtoStructValue::GetSingularField( } } case Kind::kStruct: - return DynamicMemberParsedProtoStructValue::Create( - context.value_factory(), field.type.As(), this, - &(reflect.GetMessage(value(), &field_desc))); + return context.value_factory() + .CreateReferentStructValue( + handle_from_this(), field.type.As(), + &(reflect.GetMessage(value(), &field_desc))); default: ABSL_UNREACHABLE(); } @@ -2659,13 +2338,9 @@ absl::StatusOr> ParsedProtoStructValue::GetSingularField( !field_desc.is_extension()) { return context.value_factory().CreateBytesValue( reflect.GetCord(value(), &field_desc)); - } else if (cel::base_internal::Metadata::IsReferenceCounted(*this)) { - return base_internal::ValueFactoryAccess::CreateMemberBytesValue( - context.value_factory(), - reflect.GetStringView(value(), &field_desc), this); } else { - return CreateBytesValueFromView( - reflect.GetStringView(value(), &field_desc)); + return context.value_factory().CreateReferentBytesValue( + handle_from_this(), reflect.GetStringView(value(), &field_desc)); } case google::protobuf::FieldDescriptor::TYPE_ENUM: switch (field.type->kind()) { From 276fbd9ac9cdc2f39516e8fb1d05f305ed22a44c Mon Sep 17 00:00:00 2001 From: jdtatum Date: Fri, 28 Apr 2023 16:49:48 +0000 Subject: [PATCH 229/303] Move definitions for function name constants to /base. PiperOrigin-RevId: 527911359 --- base/BUILD | 5 ++ base/builtins.h | 104 +++++++++++++++++++++++++++++++++++++ eval/public/BUILD | 3 ++ eval/public/cel_builtins.h | 85 ++---------------------------- 4 files changed, 116 insertions(+), 81 deletions(-) create mode 100644 base/builtins.h diff --git a/base/BUILD b/base/BUILD index e2a419128..578288651 100644 --- a/base/BUILD +++ b/base/BUILD @@ -429,3 +429,8 @@ cc_test( "@com_google_absl//absl/time", ], ) + +cc_library( + name = "builtins", + hdrs = ["builtins.h"], +) diff --git a/base/builtins.h b/base/builtins.h new file mode 100644 index 000000000..ec4026994 --- /dev/null +++ b/base/builtins.h @@ -0,0 +1,104 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_BUILTINS_H_ +#define THIRD_PARTY_CEL_CPP_BASE_BUILTINS_H_ + +namespace cel { + +// Constants specifying names for CEL builtins. +namespace builtin { + +// Comparison +constexpr char kEqual[] = "_==_"; +constexpr char kInequal[] = "_!=_"; +constexpr char kLess[] = "_<_"; +constexpr char kLessOrEqual[] = "_<=_"; +constexpr char kGreater[] = "_>_"; +constexpr char kGreaterOrEqual[] = "_>=_"; + +// Logical +constexpr char kAnd[] = "_&&_"; +constexpr char kOr[] = "_||_"; +constexpr char kNot[] = "!_"; + +// Strictness +constexpr char kNotStrictlyFalse[] = "@not_strictly_false"; +// Deprecated '__not_strictly_false__' function. Preserved for backwards +// compatibility with stored expressions. +constexpr char kNotStrictlyFalseDeprecated[] = "__not_strictly_false__"; + +// Arithmetical +constexpr char kAdd[] = "_+_"; +constexpr char kSubtract[] = "_-_"; +constexpr char kNeg[] = "-_"; +constexpr char kMultiply[] = "_*_"; +constexpr char kDivide[] = "_/_"; +constexpr char kModulo[] = "_%_"; + +// String operations +constexpr char kRegexMatch[] = "matches"; +constexpr char kStringContains[] = "contains"; +constexpr char kStringEndsWith[] = "endsWith"; +constexpr char kStringStartsWith[] = "startsWith"; + +// Container operations +constexpr char kIn[] = "@in"; +// Deprecated '_in_' operator. Preserved for backwards compatibility with stored +// expressions. +constexpr char kInDeprecated[] = "_in_"; +// Deprecated 'in()' function. Preserved for backwards compatibility with stored +// expressions. +constexpr char kInFunction[] = "in"; +constexpr char kIndex[] = "_[_]"; +constexpr char kSize[] = "size"; + +constexpr char kTernary[] = "_?_:_"; + +// Timestamp and Duration +constexpr char kDuration[] = "duration"; +constexpr char kTimestamp[] = "timestamp"; +constexpr char kFullYear[] = "getFullYear"; +constexpr char kMonth[] = "getMonth"; +constexpr char kDayOfYear[] = "getDayOfYear"; +constexpr char kDayOfMonth[] = "getDayOfMonth"; +constexpr char kDate[] = "getDate"; +constexpr char kDayOfWeek[] = "getDayOfWeek"; +constexpr char kHours[] = "getHours"; +constexpr char kMinutes[] = "getMinutes"; +constexpr char kSeconds[] = "getSeconds"; +constexpr char kMilliseconds[] = "getMilliseconds"; + +// Type conversions +// TODO(issues/23): Add other type conversion methods. +constexpr char kBytes[] = "bytes"; +constexpr char kDouble[] = "double"; +constexpr char kDyn[] = "dyn"; +constexpr char kInt[] = "int"; +constexpr char kString[] = "string"; +constexpr char kType[] = "type"; +constexpr char kUint[] = "uint"; + +// Runtime-only functions. +// The convention for runtime-only functions where only the runtime needs to +// differentiate behavior is to prefix the function with `#`. +// Note, this is a different convention from CEL internal functions where the +// whole stack needs to be aware of the function id. +constexpr char kRuntimeListAppend[] = "#list_append"; + +} // namespace builtin + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_BUILTINS_H_ diff --git a/eval/public/BUILD b/eval/public/BUILD index bfd939578..91acad164 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -268,6 +268,9 @@ cc_library( hdrs = [ "cel_builtins.h", ], + deps = [ + "//base:builtins", + ], ) cc_library( diff --git a/eval/public/cel_builtins.h b/eval/public/cel_builtins.h index 16c172ef4..f03e02f8c 100644 --- a/eval/public/cel_builtins.h +++ b/eval/public/cel_builtins.h @@ -1,92 +1,15 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_BUILTINS_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_BUILTINS_H_ +#include "base/builtins.h" + namespace google { namespace api { namespace expr { namespace runtime { -// Constants specifying names for CEL builtins. -namespace builtin { - -// Comparison -constexpr char kEqual[] = "_==_"; -constexpr char kInequal[] = "_!=_"; -constexpr char kLess[] = "_<_"; -constexpr char kLessOrEqual[] = "_<=_"; -constexpr char kGreater[] = "_>_"; -constexpr char kGreaterOrEqual[] = "_>=_"; - -// Logical -constexpr char kAnd[] = "_&&_"; -constexpr char kOr[] = "_||_"; -constexpr char kNot[] = "!_"; - -// Strictness -constexpr char kNotStrictlyFalse[] = "@not_strictly_false"; -// Deprecated '__not_strictly_false__' function. Preserved for backwards -// compatibility with stored expressions. -constexpr char kNotStrictlyFalseDeprecated[] = "__not_strictly_false__"; - -// Arithmetical -constexpr char kAdd[] = "_+_"; -constexpr char kSubtract[] = "_-_"; -constexpr char kNeg[] = "-_"; -constexpr char kMultiply[] = "_*_"; -constexpr char kDivide[] = "_/_"; -constexpr char kModulo[] = "_%_"; - -// String operations -constexpr char kRegexMatch[] = "matches"; -constexpr char kStringContains[] = "contains"; -constexpr char kStringEndsWith[] = "endsWith"; -constexpr char kStringStartsWith[] = "startsWith"; - -// Container operations -constexpr char kIn[] = "@in"; -// Deprecated '_in_' operator. Preserved for backwards compatibility with stored -// expressions. -constexpr char kInDeprecated[] = "_in_"; -// Deprecated 'in()' function. Preserved for backwards compatibility with stored -// expressions. -constexpr char kInFunction[] = "in"; -constexpr char kIndex[] = "_[_]"; -constexpr char kSize[] = "size"; - -constexpr char kTernary[] = "_?_:_"; - -// Timestamp and Duration -constexpr char kDuration[] = "duration"; -constexpr char kTimestamp[] = "timestamp"; -constexpr char kFullYear[] = "getFullYear"; -constexpr char kMonth[] = "getMonth"; -constexpr char kDayOfYear[] = "getDayOfYear"; -constexpr char kDayOfMonth[] = "getDayOfMonth"; -constexpr char kDate[] = "getDate"; -constexpr char kDayOfWeek[] = "getDayOfWeek"; -constexpr char kHours[] = "getHours"; -constexpr char kMinutes[] = "getMinutes"; -constexpr char kSeconds[] = "getSeconds"; -constexpr char kMilliseconds[] = "getMilliseconds"; - -// Type conversions -// TODO(issues/23): Add other type conversion methods. -constexpr char kBytes[] = "bytes"; -constexpr char kDouble[] = "double"; -constexpr char kDyn[] = "dyn"; -constexpr char kInt[] = "int"; -constexpr char kString[] = "string"; -constexpr char kType[] = "type"; -constexpr char kUint[] = "uint"; - -// Runtime-only functions. -// The convention for runtime-only functions where only the runtime needs to -// differentiate behavior is to prefix the function with `#`. -// Note, this is a different convention from CEL internal functions where the -// whole stack needs to be aware of the function id. -constexpr char kRuntimeListAppend[] = "#list_append"; - -} // namespace builtin +// Alias new namespace until external CEL users can be updated. +namespace builtin = cel::builtin; } // namespace runtime } // namespace expr From 940bc5817477c7cea67da63727848c34f917b9e3 Mon Sep 17 00:00:00 2001 From: jdtatum Date: Fri, 28 Apr 2023 17:07:59 +0000 Subject: [PATCH 230/303] Cleanup comparison functions registrar file before moving. PiperOrigin-RevId: 527916991 --- eval/public/BUILD | 7 +-- eval/public/cel_function_registry.h | 2 + eval/public/comparison_functions.cc | 73 ++++++++++++++--------------- 3 files changed, 42 insertions(+), 40 deletions(-) diff --git a/eval/public/BUILD b/eval/public/BUILD index 91acad164..11769cda5 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -323,17 +323,18 @@ cc_library( "comparison_functions.h", ], deps = [ - ":cel_builtins", ":cel_function_registry", - ":cel_number", ":cel_options", + "//base:builtins", "//base:function_adapter", "//base:handle", "//base:value", "//internal:status_macros", + "//runtime:function_registry", + "//runtime:runtime_options", + "//runtime/internal:number", "@com_google_absl//absl/status", "@com_google_absl//absl/time", - "@com_google_absl//absl/types:optional", ], ) diff --git a/eval/public/cel_function_registry.h b/eval/public/cel_function_registry.h index ced74f617..e8d484605 100644 --- a/eval/public/cel_function_registry.h +++ b/eval/public/cel_function_registry.h @@ -122,6 +122,8 @@ class CelFunctionRegistry { return modern_registry_; } + cel::FunctionRegistry& InternalGetRegistry() { return modern_registry_; } + private: cel::FunctionRegistry modern_registry_; diff --git a/eval/public/comparison_functions.cc b/eval/public/comparison_functions.cc index df48fbf85..74c237230 100644 --- a/eval/public/comparison_functions.cc +++ b/eval/public/comparison_functions.cc @@ -14,27 +14,22 @@ #include "eval/public/comparison_functions.h" -#include #include -#include -#include -#include -#include -#include #include "absl/status/status.h" #include "absl/time/time.h" -#include "absl/types/optional.h" +#include "base/builtins.h" #include "base/function_adapter.h" #include "base/handle.h" #include "base/value_factory.h" #include "base/values/bytes_value.h" #include "base/values/string_value.h" -#include "eval/public/cel_builtins.h" #include "eval/public/cel_function_registry.h" -#include "eval/public/cel_number.h" #include "eval/public/cel_options.h" #include "internal/status_macros.h" +#include "runtime/function_registry.h" +#include "runtime/internal/number.h" +#include "runtime/runtime_options.h" namespace google::api::expr::runtime { @@ -45,6 +40,7 @@ using ::cel::BytesValue; using ::cel::Handle; using ::cel::StringValue; using ::cel::ValueFactory; +using ::cel::runtime_internal::Number; // Comparison template functions template @@ -161,48 +157,49 @@ bool GreaterThanOrEqual(ValueFactory&, absl::Time t1, absl::Time t2) { template bool CrossNumericLessThan(ValueFactory&, T t, U u) { - return CelNumber(t) < CelNumber(u); + return Number(t) < Number(u); } template bool CrossNumericGreaterThan(ValueFactory&, T t, U u) { - return CelNumber(t) > CelNumber(u); + return Number(t) > Number(u); } template bool CrossNumericLessOrEqualTo(ValueFactory&, T t, U u) { - return CelNumber(t) <= CelNumber(u); + return Number(t) <= Number(u); } template bool CrossNumericGreaterOrEqualTo(ValueFactory&, T t, U u) { - return CelNumber(t) >= CelNumber(u); + return Number(t) >= Number(u); } template -absl::Status RegisterComparisonFunctionsForType(CelFunctionRegistry* registry) { +absl::Status RegisterComparisonFunctionsForType( + cel::FunctionRegistry& registry) { using FunctionAdapter = BinaryFunctionAdapter; - CEL_RETURN_IF_ERROR(registry->Register( - FunctionAdapter::CreateDescriptor(builtin::kLess, false), + CEL_RETURN_IF_ERROR(registry.Register( + FunctionAdapter::CreateDescriptor(cel::builtin::kLess, false), FunctionAdapter::WrapFunction(LessThan))); - CEL_RETURN_IF_ERROR(registry->Register( - FunctionAdapter::CreateDescriptor(builtin::kLessOrEqual, false), + CEL_RETURN_IF_ERROR(registry.Register( + FunctionAdapter::CreateDescriptor(cel::builtin::kLessOrEqual, false), FunctionAdapter::WrapFunction(LessThanOrEqual))); - CEL_RETURN_IF_ERROR(registry->Register( - FunctionAdapter::CreateDescriptor(builtin::kGreater, false), + CEL_RETURN_IF_ERROR(registry.Register( + FunctionAdapter::CreateDescriptor(cel::builtin::kGreater, false), FunctionAdapter::WrapFunction(GreaterThan))); - CEL_RETURN_IF_ERROR(registry->Register( - FunctionAdapter::CreateDescriptor(builtin::kGreaterOrEqual, false), + CEL_RETURN_IF_ERROR(registry.Register( + FunctionAdapter::CreateDescriptor(cel::builtin::kGreaterOrEqual, false), FunctionAdapter::WrapFunction(GreaterThanOrEqual))); return absl::OkStatus(); } absl::Status RegisterHomogenousComparisonFunctions( - CelFunctionRegistry* registry) { + cel::FunctionRegistry& registry) { CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType(registry)); CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType(registry)); @@ -226,30 +223,29 @@ absl::Status RegisterHomogenousComparisonFunctions( } template -absl::Status RegisterCrossNumericComparisons(CelFunctionRegistry* registry) { +absl::Status RegisterCrossNumericComparisons(cel::FunctionRegistry& registry) { using FunctionAdapter = BinaryFunctionAdapter; - CEL_RETURN_IF_ERROR(registry->Register( - FunctionAdapter::CreateDescriptor(builtin::kLess, + CEL_RETURN_IF_ERROR(registry.Register( + FunctionAdapter::CreateDescriptor(cel::builtin::kLess, /*receiver_style=*/false), FunctionAdapter::WrapFunction(&CrossNumericLessThan))); - CEL_RETURN_IF_ERROR(registry->Register( - FunctionAdapter::CreateDescriptor(builtin::kGreater, + CEL_RETURN_IF_ERROR(registry.Register( + FunctionAdapter::CreateDescriptor(cel::builtin::kGreater, /*receiver_style=*/false), FunctionAdapter::WrapFunction(&CrossNumericGreaterThan))); - CEL_RETURN_IF_ERROR(registry->Register( - FunctionAdapter::CreateDescriptor(builtin::kGreaterOrEqual, + CEL_RETURN_IF_ERROR(registry.Register( + FunctionAdapter::CreateDescriptor(cel::builtin::kGreaterOrEqual, /*receiver_style=*/false), FunctionAdapter::WrapFunction(&CrossNumericGreaterOrEqualTo))); - CEL_RETURN_IF_ERROR(registry->Register( - FunctionAdapter::CreateDescriptor(builtin::kLessOrEqual, + CEL_RETURN_IF_ERROR(registry.Register( + FunctionAdapter::CreateDescriptor(cel::builtin::kLessOrEqual, /*receiver_style=*/false), FunctionAdapter::WrapFunction(&CrossNumericLessOrEqualTo))); return absl::OkStatus(); } absl::Status RegisterHeterogeneousComparisonFunctions( - CelFunctionRegistry* registry) { - + cel::FunctionRegistry& registry) { CEL_RETURN_IF_ERROR( (RegisterCrossNumericComparisons(registry))); CEL_RETURN_IF_ERROR( @@ -284,10 +280,13 @@ absl::Status RegisterHeterogeneousComparisonFunctions( absl::Status RegisterComparisonFunctions(CelFunctionRegistry* registry, const InterpreterOptions& options) { - if (options.enable_heterogeneous_equality) { - CEL_RETURN_IF_ERROR(RegisterHeterogeneousComparisonFunctions(registry)); + cel::RuntimeOptions modern_options = ConvertToRuntimeOptions(options); + cel::FunctionRegistry& modern_registry = registry->InternalGetRegistry(); + if (modern_options.enable_heterogeneous_equality) { + CEL_RETURN_IF_ERROR( + RegisterHeterogeneousComparisonFunctions(modern_registry)); } else { - CEL_RETURN_IF_ERROR(RegisterHomogenousComparisonFunctions(registry)); + CEL_RETURN_IF_ERROR(RegisterHomogenousComparisonFunctions(modern_registry)); } return absl::OkStatus(); } From 23197c03b0574410513d77bde3ea5f103ae2bd09 Mon Sep 17 00:00:00 2001 From: tswadell Date: Fri, 28 Apr 2023 17:08:30 +0000 Subject: [PATCH 231/303] Early return on list concat against an empty input PiperOrigin-RevId: 527917187 --- eval/public/builtin_func_registrar.cc | 23 ++++++++++++++-------- eval/public/builtin_func_registrar_test.cc | 22 +++++++++++++++++++++ 2 files changed, 37 insertions(+), 8 deletions(-) diff --git a/eval/public/builtin_func_registrar.cc b/eval/public/builtin_func_registrar.cc index a69af1d89..fc6c23d12 100644 --- a/eval/public/builtin_func_registrar.cc +++ b/eval/public/builtin_func_registrar.cc @@ -394,12 +394,18 @@ absl::StatusOr> ConcatBytes(ValueFactory& factory, // Concatenation for CelList type. absl::StatusOr> ConcatList(ValueFactory& factory, - const ListValue& value1, - const ListValue& value2) { + const Handle& value1, + const Handle& value2) { std::vector joined_values; - int size1 = value1.size(); - int size2 = value2.size(); + int size1 = value1->size(); + if (size1 == 0) { + return value2; + } + int size2 = value2->size(); + if (size2 == 0) { + return value1; + } joined_values.reserve(size1 + size2); Arena* arena = cel::extensions::ProtoMemoryManager::CastToProtoArena( @@ -407,12 +413,12 @@ absl::StatusOr> ConcatList(ValueFactory& factory, ListValue::GetContext context(factory); for (int i = 0; i < size1; i++) { - CEL_ASSIGN_OR_RETURN(Handle elem, value1.Get(context, i)); + CEL_ASSIGN_OR_RETURN(Handle elem, value1->Get(context, i)); joined_values.push_back( cel::interop_internal::ModernValueToLegacyValueOrDie(arena, elem)); } for (int i = 0; i < size2; i++) { - CEL_ASSIGN_OR_RETURN(Handle elem, value2.Get(context, i)); + CEL_ASSIGN_OR_RETURN(Handle elem, value2->Get(context, i)); joined_values.push_back( cel::interop_internal::ModernValueToLegacyValueOrDie(arena, elem)); } @@ -1626,8 +1632,9 @@ absl::Status RegisterContainerFunctions(CelFunctionRegistry* registry, BinaryFunctionAdapter>, const ListValue&, const ListValue&>::CreateDescriptor(builtin::kAdd, false), - BinaryFunctionAdapter>, const ListValue&, - const ListValue&>::WrapFunction(ConcatList))); + BinaryFunctionAdapter< + absl::StatusOr>, const Handle&, + const Handle&>::WrapFunction(ConcatList))); } return registry->Register(PortableBinaryFunctionAdapter< diff --git a/eval/public/builtin_func_registrar_test.cc b/eval/public/builtin_func_registrar_test.cc index 042e1c645..1cc3144cf 100644 --- a/eval/public/builtin_func_registrar_test.cc +++ b/eval/public/builtin_func_registrar_test.cc @@ -244,6 +244,28 @@ INSTANTIATE_TEST_SUITE_P( {}, absl::OutOfRangeError("timestamp overflow"), OverflowChecksEnabled()}, + + // List concatenation tests. + {"ListConcatEmptyInputs", + "[] + [] == []", + {}, + CelValue::CreateBool(true), + OverflowChecksEnabled()}, + {"ListConcatRightEmpty", + "[1] + [] == [1]", + {}, + CelValue::CreateBool(true), + OverflowChecksEnabled()}, + {"ListConcatLeftEmpty", + "[] + [1] == [1]", + {}, + CelValue::CreateBool(true), + OverflowChecksEnabled()}, + {"ListConcat", + "[2] + [1] == [2, 1]", + {}, + CelValue::CreateBool(true), + OverflowChecksEnabled()}, }), [](const testing::TestParamInfo& info) { return info.param.test_name; From be05243756921852a194fdabf1a517e974aa9bc4 Mon Sep 17 00:00:00 2001 From: jdtatum Date: Fri, 28 Apr 2023 17:09:17 +0000 Subject: [PATCH 232/303] Move comparison functions registrar to runtime/standard. PiperOrigin-RevId: 527917420 --- eval/public/BUILD | 18 +- eval/public/comparison_functions.cc | 265 +--------------- eval/public/comparison_functions_test.cc | 71 +---- runtime/standard/BUILD | 56 ++++ runtime/standard/comparison_functions.cc | 283 ++++++++++++++++++ runtime/standard/comparison_functions.h | 36 +++ runtime/standard/comparison_functions_test.cc | 82 +++++ 7 files changed, 461 insertions(+), 350 deletions(-) create mode 100644 runtime/standard/BUILD create mode 100644 runtime/standard/comparison_functions.cc create mode 100644 runtime/standard/comparison_functions.h create mode 100644 runtime/standard/comparison_functions_test.cc diff --git a/eval/public/BUILD b/eval/public/BUILD index 11769cda5..cd82c52be 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -325,16 +325,10 @@ cc_library( deps = [ ":cel_function_registry", ":cel_options", - "//base:builtins", - "//base:function_adapter", - "//base:handle", - "//base:value", - "//internal:status_macros", "//runtime:function_registry", "//runtime:runtime_options", - "//runtime/internal:number", + "//runtime/standard:comparison_functions", "@com_google_absl//absl/status", - "@com_google_absl//absl/time", ], ) @@ -346,29 +340,19 @@ cc_test( ], deps = [ ":activation", - ":cel_builtins", ":cel_expr_builder_factory", ":cel_expression", ":cel_function_registry", ":cel_options", ":cel_value", ":comparison_functions", - ":message_wrapper", - "//eval/public/containers:container_backed_list_impl", - "//eval/public/containers:container_backed_map_impl", - "//eval/public/structs:cel_proto_wrapper", - "//eval/public/structs:trivial_legacy_type_info", "//eval/public/testing:matchers", - "//eval/testutil:test_message_cc_proto", "//internal:status_macros", "//internal:testing", "//parser", - "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", - "@com_google_absl//absl/types:span", - "@com_google_absl//absl/types:variant", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", "@com_google_googleapis//google/rpc/context:attribute_context_cc_proto", "@com_google_protobuf//:protobuf", diff --git a/eval/public/comparison_functions.cc b/eval/public/comparison_functions.cc index 74c237230..ec282704c 100644 --- a/eval/public/comparison_functions.cc +++ b/eval/public/comparison_functions.cc @@ -14,281 +14,20 @@ #include "eval/public/comparison_functions.h" -#include - #include "absl/status/status.h" -#include "absl/time/time.h" -#include "base/builtins.h" -#include "base/function_adapter.h" -#include "base/handle.h" -#include "base/value_factory.h" -#include "base/values/bytes_value.h" -#include "base/values/string_value.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" -#include "internal/status_macros.h" #include "runtime/function_registry.h" -#include "runtime/internal/number.h" #include "runtime/runtime_options.h" +#include "runtime/standard/comparison_functions.h" namespace google::api::expr::runtime { -namespace { - -using ::cel::BinaryFunctionAdapter; -using ::cel::BytesValue; -using ::cel::Handle; -using ::cel::StringValue; -using ::cel::ValueFactory; -using ::cel::runtime_internal::Number; - -// Comparison template functions -template -bool LessThan(ValueFactory&, Type t1, Type t2) { - return (t1 < t2); -} - -template -bool LessThanOrEqual(ValueFactory&, Type t1, Type t2) { - return (t1 <= t2); -} - -template -bool GreaterThan(ValueFactory& factory, Type t1, Type t2) { - return LessThan(factory, t2, t1); -} - -template -bool GreaterThanOrEqual(ValueFactory& factory, Type t1, Type t2) { - return LessThanOrEqual(factory, t2, t1); -} - -// String value comparions specializations -template <> -bool LessThan(ValueFactory&, const Handle& t1, - const Handle& t2) { - return t1->Compare(*t2) < 0; -} - -template <> -bool LessThanOrEqual(ValueFactory&, const Handle& t1, - const Handle& t2) { - return t1->Compare(*t2) <= 0; -} - -template <> -bool GreaterThan(ValueFactory&, const Handle& t1, - const Handle& t2) { - return t1->Compare(*t2) > 0; -} - -template <> -bool GreaterThanOrEqual(ValueFactory&, const Handle& t1, - const Handle& t2) { - return t1->Compare(*t2) >= 0; -} - -// bytes value comparions specializations -template <> -bool LessThan(ValueFactory&, const Handle& t1, - const Handle& t2) { - return t1->Compare(*t2) < 0; -} - -template <> -bool LessThanOrEqual(ValueFactory&, const Handle& t1, - const Handle& t2) { - return t1->Compare(*t2) <= 0; -} - -template <> -bool GreaterThan(ValueFactory&, const Handle& t1, - const Handle& t2) { - return t1->Compare(*t2) > 0; -} - -template <> -bool GreaterThanOrEqual(ValueFactory&, const Handle& t1, - const Handle& t2) { - return t1->Compare(*t2) >= 0; -} - -// Duration comparison specializations -template <> -bool LessThan(ValueFactory&, absl::Duration t1, absl::Duration t2) { - return absl::operator<(t1, t2); -} - -template <> -bool LessThanOrEqual(ValueFactory&, absl::Duration t1, absl::Duration t2) { - return absl::operator<=(t1, t2); -} - -template <> -bool GreaterThan(ValueFactory&, absl::Duration t1, absl::Duration t2) { - return absl::operator>(t1, t2); -} - -template <> -bool GreaterThanOrEqual(ValueFactory&, absl::Duration t1, absl::Duration t2) { - return absl::operator>=(t1, t2); -} - -// Timestamp comparison specializations -template <> -bool LessThan(ValueFactory&, absl::Time t1, absl::Time t2) { - return absl::operator<(t1, t2); -} - -template <> -bool LessThanOrEqual(ValueFactory&, absl::Time t1, absl::Time t2) { - return absl::operator<=(t1, t2); -} - -template <> -bool GreaterThan(ValueFactory&, absl::Time t1, absl::Time t2) { - return absl::operator>(t1, t2); -} - -template <> -bool GreaterThanOrEqual(ValueFactory&, absl::Time t1, absl::Time t2) { - return absl::operator>=(t1, t2); -} - -template -bool CrossNumericLessThan(ValueFactory&, T t, U u) { - return Number(t) < Number(u); -} - -template -bool CrossNumericGreaterThan(ValueFactory&, T t, U u) { - return Number(t) > Number(u); -} - -template -bool CrossNumericLessOrEqualTo(ValueFactory&, T t, U u) { - return Number(t) <= Number(u); -} - -template -bool CrossNumericGreaterOrEqualTo(ValueFactory&, T t, U u) { - return Number(t) >= Number(u); -} - -template -absl::Status RegisterComparisonFunctionsForType( - cel::FunctionRegistry& registry) { - using FunctionAdapter = BinaryFunctionAdapter; - CEL_RETURN_IF_ERROR(registry.Register( - FunctionAdapter::CreateDescriptor(cel::builtin::kLess, false), - FunctionAdapter::WrapFunction(LessThan))); - - CEL_RETURN_IF_ERROR(registry.Register( - FunctionAdapter::CreateDescriptor(cel::builtin::kLessOrEqual, false), - FunctionAdapter::WrapFunction(LessThanOrEqual))); - - CEL_RETURN_IF_ERROR(registry.Register( - FunctionAdapter::CreateDescriptor(cel::builtin::kGreater, false), - FunctionAdapter::WrapFunction(GreaterThan))); - - CEL_RETURN_IF_ERROR(registry.Register( - FunctionAdapter::CreateDescriptor(cel::builtin::kGreaterOrEqual, false), - FunctionAdapter::WrapFunction(GreaterThanOrEqual))); - - return absl::OkStatus(); -} - -absl::Status RegisterHomogenousComparisonFunctions( - cel::FunctionRegistry& registry) { - CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType(registry)); - - CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType(registry)); - - CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType(registry)); - - CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType(registry)); - - CEL_RETURN_IF_ERROR( - RegisterComparisonFunctionsForType&>(registry)); - - CEL_RETURN_IF_ERROR( - RegisterComparisonFunctionsForType&>(registry)); - - CEL_RETURN_IF_ERROR( - RegisterComparisonFunctionsForType(registry)); - - CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType(registry)); - - return absl::OkStatus(); -} - -template -absl::Status RegisterCrossNumericComparisons(cel::FunctionRegistry& registry) { - using FunctionAdapter = BinaryFunctionAdapter; - CEL_RETURN_IF_ERROR(registry.Register( - FunctionAdapter::CreateDescriptor(cel::builtin::kLess, - /*receiver_style=*/false), - FunctionAdapter::WrapFunction(&CrossNumericLessThan))); - CEL_RETURN_IF_ERROR(registry.Register( - FunctionAdapter::CreateDescriptor(cel::builtin::kGreater, - /*receiver_style=*/false), - FunctionAdapter::WrapFunction(&CrossNumericGreaterThan))); - CEL_RETURN_IF_ERROR(registry.Register( - FunctionAdapter::CreateDescriptor(cel::builtin::kGreaterOrEqual, - /*receiver_style=*/false), - FunctionAdapter::WrapFunction(&CrossNumericGreaterOrEqualTo))); - CEL_RETURN_IF_ERROR(registry.Register( - FunctionAdapter::CreateDescriptor(cel::builtin::kLessOrEqual, - /*receiver_style=*/false), - FunctionAdapter::WrapFunction(&CrossNumericLessOrEqualTo))); - return absl::OkStatus(); -} - -absl::Status RegisterHeterogeneousComparisonFunctions( - cel::FunctionRegistry& registry) { - CEL_RETURN_IF_ERROR( - (RegisterCrossNumericComparisons(registry))); - CEL_RETURN_IF_ERROR( - (RegisterCrossNumericComparisons(registry))); - - CEL_RETURN_IF_ERROR( - (RegisterCrossNumericComparisons(registry))); - CEL_RETURN_IF_ERROR( - (RegisterCrossNumericComparisons(registry))); - - CEL_RETURN_IF_ERROR( - (RegisterCrossNumericComparisons(registry))); - CEL_RETURN_IF_ERROR( - (RegisterCrossNumericComparisons(registry))); - - CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType(registry)); - CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType(registry)); - CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType(registry)); - CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType(registry)); - CEL_RETURN_IF_ERROR( - RegisterComparisonFunctionsForType&>(registry)); - CEL_RETURN_IF_ERROR( - RegisterComparisonFunctionsForType&>(registry)); - CEL_RETURN_IF_ERROR( - RegisterComparisonFunctionsForType(registry)); - CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType(registry)); - - return absl::OkStatus(); -} -} // namespace - - absl::Status RegisterComparisonFunctions(CelFunctionRegistry* registry, const InterpreterOptions& options) { cel::RuntimeOptions modern_options = ConvertToRuntimeOptions(options); cel::FunctionRegistry& modern_registry = registry->InternalGetRegistry(); - if (modern_options.enable_heterogeneous_equality) { - CEL_RETURN_IF_ERROR( - RegisterHeterogeneousComparisonFunctions(modern_registry)); - } else { - CEL_RETURN_IF_ERROR(RegisterHomogenousComparisonFunctions(modern_registry)); - } - return absl::OkStatus(); + return cel::RegisterComparisonFunctions(modern_registry, modern_options); } } // namespace google::api::expr::runtime diff --git a/eval/public/comparison_functions_test.cc b/eval/public/comparison_functions_test.cc index 8fedbf5d1..da2807cb4 100644 --- a/eval/public/comparison_functions_test.cc +++ b/eval/public/comparison_functions_test.cc @@ -14,43 +14,23 @@ #include "eval/public/comparison_functions.h" -#include -#include -#include #include -#include -#include +#include #include "google/api/expr/v1alpha1/syntax.pb.h" -#include "google/protobuf/any.pb.h" #include "google/rpc/context/attribute_context.pb.h" -#include "google/protobuf/descriptor.pb.h" #include "google/protobuf/arena.h" -#include "google/protobuf/descriptor.h" -#include "google/protobuf/dynamic_message.h" -#include "google/protobuf/message.h" -#include "google/protobuf/text_format.h" -#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" -#include "absl/types/span.h" -#include "absl/types/variant.h" #include "eval/public/activation.h" -#include "eval/public/cel_builtins.h" #include "eval/public/cel_expr_builder_factory.h" #include "eval/public/cel_expression.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" #include "eval/public/cel_value.h" -#include "eval/public/containers/container_backed_list_impl.h" -#include "eval/public/containers/container_backed_map_impl.h" -#include "eval/public/message_wrapper.h" -#include "eval/public/structs/cel_proto_wrapper.h" -#include "eval/public/structs/trivial_legacy_type_info.h" #include "eval/public/testing/matchers.h" -#include "eval/testutil/test_message.pb.h" // IWYU pragma: keep #include "internal/status_macros.h" #include "internal/testing.h" #include "parser/parser.h" @@ -60,13 +40,8 @@ namespace { using ::google::api::expr::v1alpha1::ParsedExpr; using ::google::rpc::context::AttributeContext; -using testing::_; using testing::Combine; -using testing::HasSubstr; -using testing::Optional; -using testing::Values; using testing::ValuesIn; -using cel::internal::StatusIs; MATCHER_P2(DefinesHomogenousOverload, name, argument_type, absl::StrCat(name, " for ", CelValue::TypeName(argument_type))) { @@ -116,50 +91,6 @@ class ComparisonFunctionTest google::protobuf::Arena arena_; }; -constexpr std::array kOrderableTypes = { - CelValue::Type::kBool, CelValue::Type::kInt64, - CelValue::Type::kUint64, CelValue::Type::kString, - CelValue::Type::kDouble, CelValue::Type::kBytes, - CelValue::Type::kDuration, CelValue::Type::kTimestamp}; - -TEST(RegisterComparisonFunctionsTest, LessThanDefined) { - InterpreterOptions default_options; - CelFunctionRegistry registry; - ASSERT_OK(RegisterComparisonFunctions(®istry, default_options)); - for (CelValue::Type type : kOrderableTypes) { - EXPECT_THAT(registry, DefinesHomogenousOverload(builtin::kLess, type)); - } -} - -TEST(RegisterComparisonFunctionsTest, LessThanOrEqualDefined) { - InterpreterOptions default_options; - CelFunctionRegistry registry; - ASSERT_OK(RegisterComparisonFunctions(®istry, default_options)); - for (CelValue::Type type : kOrderableTypes) { - EXPECT_THAT(registry, - DefinesHomogenousOverload(builtin::kLessOrEqual, type)); - } -} - -TEST(RegisterComparisonFunctionsTest, GreaterThanDefined) { - InterpreterOptions default_options; - CelFunctionRegistry registry; - ASSERT_OK(RegisterComparisonFunctions(®istry, default_options)); - for (CelValue::Type type : kOrderableTypes) { - EXPECT_THAT(registry, DefinesHomogenousOverload(builtin::kGreater, type)); - } -} - -TEST(RegisterComparisonFunctionsTest, GreaterThanOrEqualDefined) { - InterpreterOptions default_options; - CelFunctionRegistry registry; - ASSERT_OK(RegisterComparisonFunctions(®istry, default_options)); - for (CelValue::Type type : kOrderableTypes) { - EXPECT_THAT(registry, - DefinesHomogenousOverload(builtin::kGreaterOrEqual, type)); - } -} - TEST_P(ComparisonFunctionTest, SmokeTest) { ComparisonTestCase test_case = std::get<0>(GetParam()); google::protobuf::LinkMessageReflection(); diff --git a/runtime/standard/BUILD b/runtime/standard/BUILD new file mode 100644 index 000000000..02b46be20 --- /dev/null +++ b/runtime/standard/BUILD @@ -0,0 +1,56 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Provides registrars for CEL standard definitions. +package( + # Under active development, not yet being released. + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "comparison_functions", + srcs = [ + "comparison_functions.cc", + ], + hdrs = [ + "comparison_functions.h", + ], + deps = [ + "//base:builtins", + "//base:function_adapter", + "//base:handle", + "//base:value", + "//internal:status_macros", + "//runtime:function_registry", + "//runtime:runtime_options", + "//runtime/internal:number", + "@com_google_absl//absl/status", + "@com_google_absl//absl/time", + ], +) + +cc_test( + name = "comparison_functions_test", + size = "small", + srcs = [ + "comparison_functions_test.cc", + ], + deps = [ + ":comparison_functions", + "//base:builtins", + "//base:kind", + "//internal:testing", + "@com_google_absl//absl/strings", + ], +) diff --git a/runtime/standard/comparison_functions.cc b/runtime/standard/comparison_functions.cc new file mode 100644 index 000000000..b3de4ac42 --- /dev/null +++ b/runtime/standard/comparison_functions.cc @@ -0,0 +1,283 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/standard/comparison_functions.h" + +#include + +#include "absl/status/status.h" +#include "absl/time/time.h" +#include "base/builtins.h" +#include "base/function_adapter.h" +#include "base/handle.h" +#include "base/value_factory.h" +#include "base/values/bytes_value.h" +#include "base/values/string_value.h" +#include "internal/status_macros.h" +#include "runtime/function_registry.h" +#include "runtime/internal/number.h" +#include "runtime/runtime_options.h" + +namespace cel { + +namespace { + +using ::cel::runtime_internal::Number; + +// Comparison template functions +template +bool LessThan(ValueFactory&, Type t1, Type t2) { + return (t1 < t2); +} + +template +bool LessThanOrEqual(ValueFactory&, Type t1, Type t2) { + return (t1 <= t2); +} + +template +bool GreaterThan(ValueFactory& factory, Type t1, Type t2) { + return LessThan(factory, t2, t1); +} + +template +bool GreaterThanOrEqual(ValueFactory& factory, Type t1, Type t2) { + return LessThanOrEqual(factory, t2, t1); +} + +// String value comparions specializations +template <> +bool LessThan(ValueFactory&, const Handle& t1, + const Handle& t2) { + return t1->Compare(*t2) < 0; +} + +template <> +bool LessThanOrEqual(ValueFactory&, const Handle& t1, + const Handle& t2) { + return t1->Compare(*t2) <= 0; +} + +template <> +bool GreaterThan(ValueFactory&, const Handle& t1, + const Handle& t2) { + return t1->Compare(*t2) > 0; +} + +template <> +bool GreaterThanOrEqual(ValueFactory&, const Handle& t1, + const Handle& t2) { + return t1->Compare(*t2) >= 0; +} + +// bytes value comparions specializations +template <> +bool LessThan(ValueFactory&, const Handle& t1, + const Handle& t2) { + return t1->Compare(*t2) < 0; +} + +template <> +bool LessThanOrEqual(ValueFactory&, const Handle& t1, + const Handle& t2) { + return t1->Compare(*t2) <= 0; +} + +template <> +bool GreaterThan(ValueFactory&, const Handle& t1, + const Handle& t2) { + return t1->Compare(*t2) > 0; +} + +template <> +bool GreaterThanOrEqual(ValueFactory&, const Handle& t1, + const Handle& t2) { + return t1->Compare(*t2) >= 0; +} + +// Duration comparison specializations +template <> +bool LessThan(ValueFactory&, absl::Duration t1, absl::Duration t2) { + return absl::operator<(t1, t2); +} + +template <> +bool LessThanOrEqual(ValueFactory&, absl::Duration t1, absl::Duration t2) { + return absl::operator<=(t1, t2); +} + +template <> +bool GreaterThan(ValueFactory&, absl::Duration t1, absl::Duration t2) { + return absl::operator>(t1, t2); +} + +template <> +bool GreaterThanOrEqual(ValueFactory&, absl::Duration t1, absl::Duration t2) { + return absl::operator>=(t1, t2); +} + +// Timestamp comparison specializations +template <> +bool LessThan(ValueFactory&, absl::Time t1, absl::Time t2) { + return absl::operator<(t1, t2); +} + +template <> +bool LessThanOrEqual(ValueFactory&, absl::Time t1, absl::Time t2) { + return absl::operator<=(t1, t2); +} + +template <> +bool GreaterThan(ValueFactory&, absl::Time t1, absl::Time t2) { + return absl::operator>(t1, t2); +} + +template <> +bool GreaterThanOrEqual(ValueFactory&, absl::Time t1, absl::Time t2) { + return absl::operator>=(t1, t2); +} + +template +bool CrossNumericLessThan(ValueFactory&, T t, U u) { + return Number(t) < Number(u); +} + +template +bool CrossNumericGreaterThan(ValueFactory&, T t, U u) { + return Number(t) > Number(u); +} + +template +bool CrossNumericLessOrEqualTo(ValueFactory&, T t, U u) { + return Number(t) <= Number(u); +} + +template +bool CrossNumericGreaterOrEqualTo(ValueFactory&, T t, U u) { + return Number(t) >= Number(u); +} + +template +absl::Status RegisterComparisonFunctionsForType( + cel::FunctionRegistry& registry) { + using FunctionAdapter = BinaryFunctionAdapter; + CEL_RETURN_IF_ERROR(registry.Register( + FunctionAdapter::CreateDescriptor(cel::builtin::kLess, false), + FunctionAdapter::WrapFunction(LessThan))); + + CEL_RETURN_IF_ERROR(registry.Register( + FunctionAdapter::CreateDescriptor(cel::builtin::kLessOrEqual, false), + FunctionAdapter::WrapFunction(LessThanOrEqual))); + + CEL_RETURN_IF_ERROR(registry.Register( + FunctionAdapter::CreateDescriptor(cel::builtin::kGreater, false), + FunctionAdapter::WrapFunction(GreaterThan))); + + CEL_RETURN_IF_ERROR(registry.Register( + FunctionAdapter::CreateDescriptor(cel::builtin::kGreaterOrEqual, false), + FunctionAdapter::WrapFunction(GreaterThanOrEqual))); + + return absl::OkStatus(); +} + +absl::Status RegisterHomogenousComparisonFunctions( + cel::FunctionRegistry& registry) { + CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType(registry)); + + CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType(registry)); + + CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType(registry)); + + CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType(registry)); + + CEL_RETURN_IF_ERROR( + RegisterComparisonFunctionsForType&>(registry)); + + CEL_RETURN_IF_ERROR( + RegisterComparisonFunctionsForType&>(registry)); + + CEL_RETURN_IF_ERROR( + RegisterComparisonFunctionsForType(registry)); + + CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType(registry)); + + return absl::OkStatus(); +} + +template +absl::Status RegisterCrossNumericComparisons(cel::FunctionRegistry& registry) { + using FunctionAdapter = BinaryFunctionAdapter; + CEL_RETURN_IF_ERROR(registry.Register( + FunctionAdapter::CreateDescriptor(cel::builtin::kLess, + /*receiver_style=*/false), + FunctionAdapter::WrapFunction(&CrossNumericLessThan))); + CEL_RETURN_IF_ERROR(registry.Register( + FunctionAdapter::CreateDescriptor(cel::builtin::kGreater, + /*receiver_style=*/false), + FunctionAdapter::WrapFunction(&CrossNumericGreaterThan))); + CEL_RETURN_IF_ERROR(registry.Register( + FunctionAdapter::CreateDescriptor(cel::builtin::kGreaterOrEqual, + /*receiver_style=*/false), + FunctionAdapter::WrapFunction(&CrossNumericGreaterOrEqualTo))); + CEL_RETURN_IF_ERROR(registry.Register( + FunctionAdapter::CreateDescriptor(cel::builtin::kLessOrEqual, + /*receiver_style=*/false), + FunctionAdapter::WrapFunction(&CrossNumericLessOrEqualTo))); + return absl::OkStatus(); +} + +absl::Status RegisterHeterogeneousComparisonFunctions( + cel::FunctionRegistry& registry) { + CEL_RETURN_IF_ERROR( + (RegisterCrossNumericComparisons(registry))); + CEL_RETURN_IF_ERROR( + (RegisterCrossNumericComparisons(registry))); + + CEL_RETURN_IF_ERROR( + (RegisterCrossNumericComparisons(registry))); + CEL_RETURN_IF_ERROR( + (RegisterCrossNumericComparisons(registry))); + + CEL_RETURN_IF_ERROR( + (RegisterCrossNumericComparisons(registry))); + CEL_RETURN_IF_ERROR( + (RegisterCrossNumericComparisons(registry))); + + CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType(registry)); + CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType(registry)); + CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType(registry)); + CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType(registry)); + CEL_RETURN_IF_ERROR( + RegisterComparisonFunctionsForType&>(registry)); + CEL_RETURN_IF_ERROR( + RegisterComparisonFunctionsForType&>(registry)); + CEL_RETURN_IF_ERROR( + RegisterComparisonFunctionsForType(registry)); + CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType(registry)); + + return absl::OkStatus(); +} +} // namespace + +absl::Status RegisterComparisonFunctions(FunctionRegistry& registry, + const RuntimeOptions& options) { + if (options.enable_heterogeneous_equality) { + CEL_RETURN_IF_ERROR(RegisterHeterogeneousComparisonFunctions(registry)); + } else { + CEL_RETURN_IF_ERROR(RegisterHomogenousComparisonFunctions(registry)); + } + return absl::OkStatus(); +} + +} // namespace cel diff --git a/runtime/standard/comparison_functions.h b/runtime/standard/comparison_functions.h new file mode 100644 index 000000000..4b19f85ed --- /dev/null +++ b/runtime/standard/comparison_functions.h @@ -0,0 +1,36 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_COMPARISON_FUNCTIONS_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_COMPARISON_FUNCTIONS_H_ + +#include "absl/status/status.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" + +namespace cel { + +// Register built in comparison functions (<, <=, >, >=). +// +// Most users should prefer to use RegisterBuiltinFunctions. +// +// This is call is included in RegisterBuiltinFunctions -- calling both +// RegisterBuiltinFunctions and RegisterComparisonFunctions directly on the same +// registry will result in an error. +absl::Status RegisterComparisonFunctions(FunctionRegistry& registry, + const RuntimeOptions& options); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_COMPARISON_FUNCTIONS_H_ diff --git a/runtime/standard/comparison_functions_test.cc b/runtime/standard/comparison_functions_test.cc new file mode 100644 index 000000000..cd85fdeec --- /dev/null +++ b/runtime/standard/comparison_functions_test.cc @@ -0,0 +1,82 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/standard/comparison_functions.h" + +#include + +#include "absl/strings/str_cat.h" +#include "base/builtins.h" +#include "base/kind.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +MATCHER_P2(DefinesHomogenousOverload, name, argument_kind, + absl::StrCat(name, " for ", KindToString(argument_kind))) { + const cel::FunctionRegistry& registry = arg; + return !registry + .FindStaticOverloads(name, /*receiver_style=*/false, + {argument_kind, argument_kind}) + .empty(); +} + +constexpr std::array kOrderableTypes = { + Kind::kBool, Kind::kInt64, Kind::kUint64, Kind::kString, + Kind::kDouble, Kind::kBytes, Kind::kDuration, Kind::kTimestamp}; + +TEST(RegisterComparisonFunctionsTest, LessThanDefined) { + RuntimeOptions default_options; + FunctionRegistry registry; + ASSERT_OK(RegisterComparisonFunctions(registry, default_options)); + for (Kind kind : kOrderableTypes) { + EXPECT_THAT(registry, DefinesHomogenousOverload(builtin::kLess, kind)); + } +} + +TEST(RegisterComparisonFunctionsTest, LessThanOrEqualDefined) { + RuntimeOptions default_options; + FunctionRegistry registry; + ASSERT_OK(RegisterComparisonFunctions(registry, default_options)); + for (Kind kind : kOrderableTypes) { + EXPECT_THAT(registry, + DefinesHomogenousOverload(builtin::kLessOrEqual, kind)); + } +} + +TEST(RegisterComparisonFunctionsTest, GreaterThanDefined) { + RuntimeOptions default_options; + FunctionRegistry registry; + ASSERT_OK(RegisterComparisonFunctions(registry, default_options)); + for (Kind kind : kOrderableTypes) { + EXPECT_THAT(registry, DefinesHomogenousOverload(builtin::kGreater, kind)); + } +} + +TEST(RegisterComparisonFunctionsTest, GreaterThanOrEqualDefined) { + RuntimeOptions default_options; + FunctionRegistry registry; + ASSERT_OK(RegisterComparisonFunctions(registry, default_options)); + for (Kind kind : kOrderableTypes) { + EXPECT_THAT(registry, + DefinesHomogenousOverload(builtin::kGreaterOrEqual, kind)); + } +} + +// TODO(issues/5): move functional tests from wrapper library after top-level +// APIs are available for planning and running an expression. + +} // namespace +} // namespace cel From 5f1f4e88cf6bff5259233b047a9fdd086e344ba0 Mon Sep 17 00:00:00 2001 From: jcking Date: Mon, 1 May 2023 14:48:54 +0000 Subject: [PATCH 233/303] Rename `Referent` to `Borrowed` and introduce `Owner` PiperOrigin-RevId: 528470621 --- base/BUILD | 9 ++ base/handle.h | 38 ------- base/internal/data.h | 25 +++++ base/owner.h | 165 ++++++++++++++++++++++++++++ base/type.h | 17 ++- base/value.cc | 2 - base/value.h | 12 +- base/value_factory.h | 55 +++++----- base/values/list_value.h | 3 +- base/values/map_value.h | 3 +- base/values/struct_value.h | 3 +- extensions/protobuf/struct_value.cc | 128 ++++++++++----------- 12 files changed, 313 insertions(+), 147 deletions(-) create mode 100644 base/owner.h diff --git a/base/BUILD b/base/BUILD index 578288651..eacde47b1 100644 --- a/base/BUILD +++ b/base/BUILD @@ -72,6 +72,14 @@ cc_library( ], ) +cc_library( + name = "owner", + hdrs = ["owner.h"], + deps = [ + "//base/internal:data", + ], +) + cc_library( name = "kind", srcs = ["kind.cc"], @@ -256,6 +264,7 @@ cc_library( ":handle", ":kind", ":memory_manager", + ":owner", ":type", "//base/internal:data", "//base/internal:message_wrapper", diff --git a/base/handle.h b/base/handle.h index 4a846b4e2..2d14e8353 100644 --- a/base/handle.h +++ b/base/handle.h @@ -256,8 +256,6 @@ class Handle final : private base_internal::HandlePolicy { explicit Handle(absl::in_place_t, Args&&... args) : impl_(std::forward(args)...) {} - T* Release() { return static_cast(impl_.release()); } - Impl impl_; }; @@ -312,42 +310,6 @@ struct HandleFactory { return Handle(absl::in_place, *base_internal::ManagedMemoryRelease(managed_memory)); } - - template - static std::enable_if_t, Handle> MakeFrom( - const F* data) { - if (Metadata::IsReferenceCounted(*data)) { - Metadata::Ref(*data); - } else { - ABSL_ASSERT(Metadata::IsArenaAllocated(*data)); - } - return Handle(absl::in_place, *const_cast(data)); - } - - // Requires that T is reference counted or arena allocated. Clears Handle - // without decrementing the reference count. Returns nullptr if T is arena - // allocated, T* otherwise. - template - static std::enable_if_t, F*> Release( - Handle& handle) { - T* data = handle.Release(); - if (Metadata::IsArenaAllocated(*data)) { - data = nullptr; - } else { - ABSL_ASSERT(Metadata::IsReferenceCounted(*data)); - } - return data; - } -}; - -template -class EnableHandleFromThis { - protected: - Handle handle_from_this() const { - static_assert(std::is_base_of_v, T>); - // It is guaranteed that we are either reference counted or arena allocated. - return HandleFactory::MakeFrom(reinterpret_cast(this)); - } }; } // namespace cel::base_internal diff --git a/base/internal/data.h b/base/internal/data.h index 3a698498b..270c33975 100644 --- a/base/internal/data.h +++ b/base/internal/data.h @@ -31,6 +31,8 @@ namespace cel { +class Type; +class Value; class MemoryManager; namespace base_internal { @@ -343,6 +345,29 @@ class Metadata final { Metadata& operator=(Metadata&&) = delete; }; +class TypeMetadata; +class ValueMetadata; + +template +struct SelectMetadataImpl; + +template +struct SelectMetadataImpl< + T, std::enable_if_t, + std::is_base_of>>> { + using type = TypeMetadata; +}; + +template +struct SelectMetadataImpl< + T, std::enable_if_t, + std::is_base_of>>> { + using type = ValueMetadata; +}; + +template +using SelectMetadata = typename SelectMetadataImpl::type; + template union alignas(Align) AnyDataStorage final { AnyDataStorage() : pointer(0) {} diff --git a/base/owner.h b/base/owner.h new file mode 100644 index 000000000..7c41f14e6 --- /dev/null +++ b/base/owner.h @@ -0,0 +1,165 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_OWNER_H_ +#define THIRD_PARTY_CEL_CPP_BASE_OWNER_H_ + +#include + +#include "base/internal/data.h" + +namespace cel { + +class Type; +class Value; +class TypeFactory; +class ValueFactory; + +template +class EnableOwnerFromThis; + +// Owner is a special type used for creating borrowed types and values. It +// represents the actual owner of the data. The created borrowed value will +// ensure that the Owner is alive for as long as it is alive. +template +class Owner { + private: + using metadata_type = base_internal::SelectMetadata; + + public: + Owner() = delete; + + Owner(const Owner& other) noexcept : owner_(other.owner_) { + if (owner_ != nullptr) { + metadata_type::Ref(*owner_); + } + } + + Owner(Owner&& other) noexcept : owner_(other.owner_) { + other.owner_ = nullptr; + } + + template >> + Owner(const Owner& other) noexcept : owner_(other.owner_) { // NOLINT + if (owner_ != nullptr) { + metadata_type::Ref(*owner_); + } + } + + template >> + Owner(Owner&& other) : owner_(other.owner_) { // NOLINT + other.owner_ = nullptr; + } + + Owner& operator=(const Owner& other) noexcept { + if (this != &other) { + if (static_cast(other)) { + metadata_type::Ref(*other.owner_); + } + if (static_cast(*this)) { + metadata_type::Unref(*owner_); + } + owner_ = other.owner_; + } + return *this; + } + + Owner& operator=(Owner&& other) noexcept { + if (this != &other) { + if (static_cast(*this)) { + metadata_type::Unref(*owner_); + } + owner_ = other.owner_; + other.owner_ = nullptr; + } + return *this; + } + + template >> + Owner& operator=(const Owner& other) noexcept { + if (this != &other) { + if (static_cast(other)) { + metadata_type::Ref(*other.owner_); + } + if (static_cast(*this)) { + metadata_type::Unref(*owner_); + } + owner_ = other.owner_; + } + return *this; + } + + template >> + Owner& operator=(Owner&& other) { // NOLINT + if (this != &other) { + if (static_cast(*this)) { + metadata_type::Unref(*owner_); + } + owner_ = other.owner_; + other.owner_ = nullptr; + } + return *this; + } + + ~Owner() { + if (static_cast(*this)) { + metadata_type::Unref(*owner_); + } + } + + explicit operator bool() const { return owner_ != nullptr; } + + private: + template + friend class Owner; + template + friend class EnableOwnerFromThis; + friend class TypeFactory; + friend class ValueFactory; + + explicit Owner(const T* owner) : owner_(owner) {} + + const T* release() { + const T* owner = owner_; + owner_ = nullptr; + return owner; + } + + const T* owner_; +}; + +template +class EnableOwnerFromThis { + protected: + Owner owner_from_this() const { + static_assert(std::is_base_of_v, T>); + static_assert(std::is_base_of_v); + using metadata_type = base_internal::SelectMetadata; + const T* owner = reinterpret_cast(this); + if (metadata_type::IsReferenceCounted(*owner)) { + metadata_type::Ref(*owner); + } else { + owner = nullptr; + } + return Owner(owner); + } +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_OWNER_H_ diff --git a/base/type.h b/base/type.h index 1df85a36d..123a4c62d 100644 --- a/base/type.h +++ b/base/type.h @@ -133,6 +133,17 @@ namespace cel { namespace base_internal { +class TypeMetadata final { + public: + TypeMetadata() = delete; + + static void Ref(const Type& type); + + static void Unref(const Type& type); + + static bool IsReferenceCounted(const Type& type); +}; + class TypeHandle final { public: TypeHandle() = default; @@ -172,12 +183,6 @@ class TypeHandle final { void HashValue(absl::HashState state) const; - Type* release() { - Type* type = static_cast(data_.get_heap()); - data_.set_pointer(0); - return type; - } - private: static bool Equals(const Type& lhs, const Type& rhs, Kind kind); diff --git a/base/value.cc b/base/value.cc index ed7e9e25d..571c11606 100644 --- a/base/value.cc +++ b/base/value.cc @@ -456,8 +456,6 @@ void ValueHandle::Delete(Kind kind, const Value& value) { } } -void ValueMetadata::Ref(const Value& value) { Metadata::Ref(value); } - void ValueMetadata::Unref(const Value& value) { if (Metadata::Unref(value)) { ValueHandle::Delete(Metadata::KindHeap(value), value); diff --git a/base/value.h b/base/value.h index 3ebde96a2..9155e821f 100644 --- a/base/value.h +++ b/base/value.h @@ -124,9 +124,13 @@ class ValueMetadata final { public: ValueMetadata() = delete; - static void Ref(const Value& value); + static void Ref(const Value& value) { Metadata::Ref(value); } static void Unref(const Value& value); + + static bool IsReferenceCounted(const Value& value) { + return Metadata::IsReferenceCounted(value); + } }; class ValueHandle final { @@ -166,12 +170,6 @@ class ValueHandle final { bool Equals(const ValueHandle& other) const; - Value* release() { - Value* value = static_cast(data_.get_heap()); - data_.set_pointer(0); - return value; - } - private: friend class ValueMetadata; diff --git a/base/value_factory.h b/base/value_factory.h index 7685ad872..5f213bbe0 100644 --- a/base/value_factory.h +++ b/base/value_factory.h @@ -32,6 +32,7 @@ #include "base/function_result_set.h" #include "base/handle.h" #include "base/memory_manager.h" +#include "base/owner.h" #include "base/type_manager.h" #include "base/value.h" #include "base/values/bool_value.h" @@ -57,18 +58,18 @@ namespace cel { namespace base_internal { template -class ReferentValue final : public T { +class BorrowedValue final : public T { public: template - explicit ReferentValue(const cel::Value* referent, Args&&... args) + explicit BorrowedValue(const cel::Value* owner, Args&&... args) : T(std::forward(args)...), - referent_(ABSL_DIE_IF_NULL(referent)) // Crash OK + owner_(ABSL_DIE_IF_NULL(owner)) // Crash OK {} - ~ReferentValue() override { ValueMetadata::Unref(*referent_); } + ~BorrowedValue() override { ValueMetadata::Unref(*owner_); } private: - const cel::Value* const referent_; + const cel::Value* const owner_; }; } // namespace base_internal @@ -163,11 +164,11 @@ class ValueFactory final { template EnableIfReferent>> - CreateReferentBytesValue(Handle reference, absl::string_view value) { + CreateBorrowedBytesValue(Owner owner, absl::string_view value) { if (value.empty()) { return GetEmptyBytesValue(); } - auto* pointer = base_internal::HandleFactory::Release(reference); + auto* pointer = owner.release(); if (pointer == nullptr) { return base_internal::HandleFactory::Make< base_internal::InlinedStringViewBytesValue>(value); @@ -215,11 +216,11 @@ class ValueFactory final { template EnableIfReferent>> - CreateReferentStringValue(Handle reference, absl::string_view value) { + CreateBorrowedStringValue(Owner owner, absl::string_view value) { if (value.empty()) { return GetEmptyStringValue(); } - auto* pointer = base_internal::HandleFactory::Release(reference); + auto* pointer = owner.release(); if (pointer == nullptr) { return base_internal::HandleFactory::Make< base_internal::InlinedStringViewStringValue>(value); @@ -295,27 +296,27 @@ class ValueFactory final { template EnableIfBaseOfAndReferent>> - CreateReferentStructValue(Handle reference, const Handle& type, + CreateBorrowedStructValue(Owner owner, const Handle& type, Args&&... args) ABSL_ATTRIBUTE_LIFETIME_BOUND { - auto* pointer = base_internal::HandleFactory::Release(reference); + auto* pointer = owner.release(); if (pointer == nullptr) { return CreateStructValue(type, std::forward(args)...); } return base_internal::HandleFactory::template Make< - base_internal::ReferentValue>(memory_manager(), pointer, type, + base_internal::BorrowedValue>(memory_manager(), pointer, type, std::forward(args)...); } template EnableIfBaseOfAndReferent>> - CreateReferentStructValue(Handle reference, Handle&& type, + CreateBorrowedStructValue(Owner owner, Handle&& type, Args&&... args) ABSL_ATTRIBUTE_LIFETIME_BOUND { - auto* pointer = base_internal::HandleFactory::Release(reference); + auto* pointer = owner.release(); if (pointer == nullptr) { return CreateStructValue(std::move(type), std::forward(args)...); } return base_internal::HandleFactory::template Make< - base_internal::ReferentValue>(memory_manager(), pointer, + base_internal::BorrowedValue>(memory_manager(), pointer, std::move(type), std::forward(args)...); } @@ -339,27 +340,27 @@ class ValueFactory final { template EnableIfBaseOfAndReferent>> - CreateReferentListValue(Handle reference, const Handle& type, + CreateBorrowedListValue(Owner owner, const Handle& type, Args&&... args) ABSL_ATTRIBUTE_LIFETIME_BOUND { - auto* pointer = base_internal::HandleFactory::Release(reference); + auto* pointer = owner.release(); if (pointer == nullptr) { return CreateListValue(type, std::forward(args)...); } return base_internal::HandleFactory::template Make< - base_internal::ReferentValue>(memory_manager(), pointer, type, + base_internal::BorrowedValue>(memory_manager(), pointer, type, std::forward(args)...); } template EnableIfBaseOfAndReferent>> - CreateReferentListValue(Handle reference, Handle&& type, + CreateBorrowedListValue(Owner owner, Handle&& type, Args&&... args) ABSL_ATTRIBUTE_LIFETIME_BOUND { - auto* pointer = base_internal::HandleFactory::Release(reference); + auto* pointer = owner.release(); if (pointer == nullptr) { return CreateListValue(std::move(type), std::forward(args)...); } return base_internal::HandleFactory::template Make< - base_internal::ReferentValue>(memory_manager(), pointer, + base_internal::BorrowedValue>(memory_manager(), pointer, std::move(type), std::forward(args)...); } @@ -383,27 +384,27 @@ class ValueFactory final { template EnableIfBaseOfAndReferent>> - CreateReferentMapValue(Handle reference, const Handle& type, + CreateBorrowedMapValue(Owner owner, const Handle& type, Args&&... args) ABSL_ATTRIBUTE_LIFETIME_BOUND { - auto* pointer = base_internal::HandleFactory::Release(reference); + auto* pointer = owner.release(); if (pointer == nullptr) { return CreateMapValue(type, std::forward(args)...); } return base_internal::HandleFactory::template Make< - base_internal::ReferentValue>(memory_manager(), pointer, type, + base_internal::BorrowedValue>(memory_manager(), pointer, type, std::forward(args)...); } template EnableIfBaseOfAndReferent>> - CreateReferentMapValue(Handle reference, Handle&& type, + CreateBorrowedMapValue(Owner owner, Handle&& type, Args&&... args) ABSL_ATTRIBUTE_LIFETIME_BOUND { - auto* pointer = base_internal::HandleFactory::Release(reference); + auto* pointer = owner.release(); if (pointer == nullptr) { return CreateMapValue(std::move(type), std::forward(args)...); } return base_internal::HandleFactory::template Make< - base_internal::ReferentValue>(memory_manager(), pointer, + base_internal::BorrowedValue>(memory_manager(), pointer, std::move(type), std::forward(args)...); } diff --git a/base/values/list_value.h b/base/values/list_value.h index a14b5a0ce..90590957e 100644 --- a/base/values/list_value.h +++ b/base/values/list_value.h @@ -29,6 +29,7 @@ #include "base/handle.h" #include "base/internal/data.h" #include "base/kind.h" +#include "base/owner.h" #include "base/type.h" #include "base/types/list_type.h" #include "base/value.h" @@ -160,7 +161,7 @@ class LegacyListValue final : public ListValue, public InlineData { class AbstractListValue : public ListValue, public HeapData, - public EnableHandleFromThis { + public EnableOwnerFromThis { public: static bool Is(const Value& value) { return value.kind() == kKind && diff --git a/base/values/map_value.h b/base/values/map_value.h index 4c46be9df..ba3ac382c 100644 --- a/base/values/map_value.h +++ b/base/values/map_value.h @@ -26,6 +26,7 @@ #include "base/handle.h" #include "base/internal/data.h" #include "base/kind.h" +#include "base/owner.h" #include "base/type.h" #include "base/types/map_type.h" #include "base/value.h" @@ -187,7 +188,7 @@ class LegacyMapValue final : public MapValue, public InlineData { class AbstractMapValue : public MapValue, public HeapData, - public EnableHandleFromThis { + public EnableOwnerFromThis { public: static bool Is(const Value& value) { return value.kind() == kKind && diff --git a/base/values/struct_value.h b/base/values/struct_value.h index 7d0186147..596b5e0af 100644 --- a/base/values/struct_value.h +++ b/base/values/struct_value.h @@ -27,6 +27,7 @@ #include "base/handle.h" #include "base/internal/data.h" #include "base/kind.h" +#include "base/owner.h" #include "base/type.h" #include "base/types/struct_type.h" #include "base/value.h" @@ -229,7 +230,7 @@ class LegacyStructValue final : public StructValue, public InlineData { class AbstractStructValue : public StructValue, public HeapData, - public EnableHandleFromThis { + public EnableOwnerFromThis { public: static bool Is(const Value& value) { return value.kind() == kKind && diff --git a/extensions/protobuf/struct_value.cc b/extensions/protobuf/struct_value.cc index e7cb7c02a..76224c34e 100644 --- a/extensions/protobuf/struct_value.cc +++ b/extensions/protobuf/struct_value.cc @@ -758,9 +758,9 @@ class ParsedProtoListValue // Scratch was not used, we can avoid copying. scratch.reset(); return context.value_factory() - .CreateReferentStructValue< + .CreateBorrowedStructValue< protobuf_internal::DynamicMemberParsedProtoStructValue>( - handle_from_this(), type()->element().As(), &field); + owner_from_this(), type()->element().As(), &field); } if (ProtoMemoryManager::Is(context.value_factory().memory_manager())) { auto* arena = ProtoMemoryManager::CastToProtoArena( @@ -1014,8 +1014,8 @@ class ParsedProtoListValue CEL_ASSIGN_OR_RETURN(auto wrapped, protobuf_internal::UnwrapStringValueProto(field)); if (absl::holds_alternative(wrapped)) { - return context.value_factory().CreateReferentStringValue( - handle_from_this(), absl::get(wrapped)); + return context.value_factory().CreateBorrowedStringValue( + owner_from_this(), absl::get(wrapped)); } else { ABSL_ASSERT(absl::holds_alternative(wrapped)); return context.value_factory().CreateStringValue( @@ -1297,8 +1297,8 @@ class ParsedProtoMapValueKeysList : public CEL_LIST_VALUE_CLASS { case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: return context.value_factory().CreateUintValue(key.GetUInt32Value()); case google::protobuf::FieldDescriptor::CPPTYPE_STRING: - return context.value_factory().CreateReferentStringValue( - handle_from_this(), key.GetStringValue()); + return context.value_factory().CreateBorrowedStringValue( + owner_from_this(), key.GetStringValue()); case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: return context.value_factory().CreateBoolValue(key.GetBoolValue()); default: @@ -1376,11 +1376,11 @@ class ParsedProtoMapValue : public CEL_MAP_VALUE_CLASS { proto_value.GetDoubleValue()); case google::protobuf::FieldDescriptor::CPPTYPE_STRING: { if (value_desc->type() == google::protobuf::FieldDescriptor::TYPE_BYTES) { - return context.value_factory().CreateReferentBytesValue( - handle_from_this(), proto_value.GetStringValue()); + return context.value_factory().CreateBorrowedBytesValue( + owner_from_this(), proto_value.GetStringValue()); } else { - return context.value_factory().CreateReferentStringValue( - handle_from_this(), proto_value.GetStringValue()); + return context.value_factory().CreateBorrowedStringValue( + owner_from_this(), proto_value.GetStringValue()); } } case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: { @@ -1460,8 +1460,8 @@ class ParsedProtoMapValue : public CEL_MAP_VALUE_CLASS { protobuf_internal::UnwrapStringValueProto( proto_value.GetMessageValue())); if (absl::holds_alternative(wrapped)) { - return context.value_factory().CreateReferentStringValue( - handle_from_this(), absl::get(wrapped)); + return context.value_factory().CreateBorrowedStringValue( + owner_from_this(), absl::get(wrapped)); } else { ABSL_ASSERT(absl::holds_alternative(wrapped)); return context.value_factory().CreateStringValue( @@ -1478,9 +1478,9 @@ class ParsedProtoMapValue : public CEL_MAP_VALUE_CLASS { } case Kind::kStruct: return context.value_factory() - .CreateReferentStructValue< + .CreateBorrowedStructValue< protobuf_internal::DynamicMemberParsedProtoStructValue>( - handle_from_this(), std::move(type).As(), + owner_from_this(), std::move(type).As(), &proto_value.GetMessageValue()); default: return absl::InternalError(absl::StrCat( @@ -1522,8 +1522,8 @@ class ParsedProtoMapValue : public CEL_MAP_VALUE_CLASS { keys.push_back(begin.GetKey()); } return context.value_factory() - .CreateReferentListValue( - handle_from_this(), std::move(list_type), std::move(keys)); + .CreateBorrowedListValue( + owner_from_this(), std::move(list_type), std::move(keys)); } private: @@ -2028,8 +2028,8 @@ absl::StatusOr> ParsedProtoStructValue::GetMapField( const GetFieldContext& context, const StructType::Field& field, const google::protobuf::Reflection& reflect, const google::protobuf::FieldDescriptor& field_desc) const { - return context.value_factory().CreateReferentMapValue( - handle_from_this(), field.type.As(), value(), field_desc); + return context.value_factory().CreateBorrowedMapValue( + owner_from_this(), field.type.As(), value(), field_desc); } absl::StatusOr> ParsedProtoStructValue::GetRepeatedField( @@ -2039,13 +2039,13 @@ absl::StatusOr> ParsedProtoStructValue::GetRepeatedField( switch (field_desc.type()) { case google::protobuf::FieldDescriptor::TYPE_DOUBLE: return context.value_factory() - .CreateReferentListValue>( - handle_from_this(), field.type.As(), + .CreateBorrowedListValue>( + owner_from_this(), field.type.As(), reflect.GetRepeatedFieldRef(value(), &field_desc)); case google::protobuf::FieldDescriptor::TYPE_FLOAT: return context.value_factory() - .CreateReferentListValue>( - handle_from_this(), field.type.As(), + .CreateBorrowedListValue>( + owner_from_this(), field.type.As(), reflect.GetRepeatedFieldRef(value(), &field_desc)); case google::protobuf::FieldDescriptor::TYPE_INT64: ABSL_FALLTHROUGH_INTENDED; @@ -2053,8 +2053,8 @@ absl::StatusOr> ParsedProtoStructValue::GetRepeatedField( ABSL_FALLTHROUGH_INTENDED; case google::protobuf::FieldDescriptor::TYPE_SINT64: return context.value_factory() - .CreateReferentListValue>( - handle_from_this(), field.type.As(), + .CreateBorrowedListValue>( + owner_from_this(), field.type.As(), reflect.GetRepeatedFieldRef(value(), &field_desc)); case google::protobuf::FieldDescriptor::TYPE_INT32: ABSL_FALLTHROUGH_INTENDED; @@ -2062,33 +2062,33 @@ absl::StatusOr> ParsedProtoStructValue::GetRepeatedField( ABSL_FALLTHROUGH_INTENDED; case google::protobuf::FieldDescriptor::TYPE_SINT32: return context.value_factory() - .CreateReferentListValue>( - handle_from_this(), field.type.As(), + .CreateBorrowedListValue>( + owner_from_this(), field.type.As(), reflect.GetRepeatedFieldRef(value(), &field_desc)); case google::protobuf::FieldDescriptor::TYPE_UINT64: ABSL_FALLTHROUGH_INTENDED; case google::protobuf::FieldDescriptor::TYPE_FIXED64: return context.value_factory() - .CreateReferentListValue>( - handle_from_this(), field.type.As(), + .CreateBorrowedListValue>( + owner_from_this(), field.type.As(), reflect.GetRepeatedFieldRef(value(), &field_desc)); case google::protobuf::FieldDescriptor::TYPE_FIXED32: ABSL_FALLTHROUGH_INTENDED; case google::protobuf::FieldDescriptor::TYPE_UINT32: return context.value_factory() - .CreateReferentListValue>( - handle_from_this(), field.type.As(), + .CreateBorrowedListValue>( + owner_from_this(), field.type.As(), reflect.GetRepeatedFieldRef(value(), &field_desc)); case google::protobuf::FieldDescriptor::TYPE_BOOL: return context.value_factory() - .CreateReferentListValue>( - handle_from_this(), field.type.As(), + .CreateBorrowedListValue>( + owner_from_this(), field.type.As(), reflect.GetRepeatedFieldRef(value(), &field_desc)); case google::protobuf::FieldDescriptor::TYPE_STRING: return context.value_factory() - .CreateReferentListValue< + .CreateBorrowedListValue< ParsedProtoListValue>( - handle_from_this(), field.type.As(), + owner_from_this(), field.type.As(), reflect.GetRepeatedFieldRef(value(), &field_desc)); case google::protobuf::FieldDescriptor::TYPE_GROUP: ABSL_FALLTHROUGH_INTENDED; @@ -2096,77 +2096,77 @@ absl::StatusOr> ParsedProtoStructValue::GetRepeatedField( switch (field.type.As()->element()->kind()) { case Kind::kDuration: return context.value_factory() - .CreateReferentListValue< + .CreateBorrowedListValue< ParsedProtoListValue>( - handle_from_this(), field.type.As(), + owner_from_this(), field.type.As(), reflect.GetRepeatedFieldRef(value(), &field_desc)); case Kind::kTimestamp: return context.value_factory() - .CreateReferentListValue< + .CreateBorrowedListValue< ParsedProtoListValue>( - handle_from_this(), field.type.As(), + owner_from_this(), field.type.As(), reflect.GetRepeatedFieldRef(value(), &field_desc)); case Kind::kBool: // google.protobuf.BoolValue, mapped to CEL primitive bool type for // list elements. return context.value_factory() - .CreateReferentListValue< + .CreateBorrowedListValue< ParsedProtoListValue>( - handle_from_this(), field.type.As(), + owner_from_this(), field.type.As(), reflect.GetRepeatedFieldRef(value(), &field_desc)); case Kind::kBytes: // google.protobuf.BytesValue, mapped to CEL primitive bytes type for // list elements. return context.value_factory() - .CreateReferentListValue< + .CreateBorrowedListValue< ParsedProtoListValue>( - handle_from_this(), field.type.As(), + owner_from_this(), field.type.As(), reflect.GetRepeatedFieldRef(value(), &field_desc)); case Kind::kDouble: // google.protobuf.{FloatValue,DoubleValue}, mapped to CEL primitive // double type for list elements. return context.value_factory() - .CreateReferentListValue< + .CreateBorrowedListValue< ParsedProtoListValue>( - handle_from_this(), field.type.As(), + owner_from_this(), field.type.As(), reflect.GetRepeatedFieldRef(value(), &field_desc)); case Kind::kInt: // google.protobuf.{Int32Value,Int64Value}, mapped to CEL primitive // int type for list elements. return context.value_factory() - .CreateReferentListValue< + .CreateBorrowedListValue< ParsedProtoListValue>( - handle_from_this(), field.type.As(), + owner_from_this(), field.type.As(), reflect.GetRepeatedFieldRef(value(), &field_desc)); case Kind::kString: // google.protobuf.StringValue, mapped to CEL primitive bytes type for // list elements. return context.value_factory() - .CreateReferentListValue< + .CreateBorrowedListValue< ParsedProtoListValue>( - handle_from_this(), field.type.As(), + owner_from_this(), field.type.As(), reflect.GetRepeatedFieldRef(value(), &field_desc)); case Kind::kUint: // google.protobuf.{UInt32Value,UInt64Value}, mapped to CEL primitive // uint type for list elements. return context.value_factory() - .CreateReferentListValue< + .CreateBorrowedListValue< ParsedProtoListValue>( - handle_from_this(), field.type.As(), + owner_from_this(), field.type.As(), reflect.GetRepeatedFieldRef(value(), &field_desc)); case Kind::kStruct: return context.value_factory() - .CreateReferentListValue< + .CreateBorrowedListValue< ParsedProtoListValue>( - handle_from_this(), field.type.As(), + owner_from_this(), field.type.As(), reflect.GetRepeatedFieldRef(value(), &field_desc)); default: @@ -2174,9 +2174,9 @@ absl::StatusOr> ParsedProtoStructValue::GetRepeatedField( } case google::protobuf::FieldDescriptor::TYPE_BYTES: return context.value_factory() - .CreateReferentListValue< + .CreateBorrowedListValue< ParsedProtoListValue>( - handle_from_this(), field.type.As(), + owner_from_this(), field.type.As(), reflect.GetRepeatedFieldRef(value(), &field_desc)); case google::protobuf::FieldDescriptor::TYPE_ENUM: switch (field.type.As()->element()->kind()) { @@ -2188,9 +2188,9 @@ absl::StatusOr> ParsedProtoStructValue::GetRepeatedField( .size()); case Kind::kEnum: return context.value_factory() - .CreateReferentListValue< + .CreateBorrowedListValue< ParsedProtoListValue>( - handle_from_this(), field.type.As(), + owner_from_this(), field.type.As(), reflect.GetRepeatedFieldRef(value(), &field_desc)); default: ABSL_UNREACHABLE(); @@ -2242,8 +2242,8 @@ absl::StatusOr> ParsedProtoStructValue::GetSingularField( return context.value_factory().CreateStringValue( reflect.GetCord(value(), &field_desc)); } else { - return context.value_factory().CreateReferentStringValue( - handle_from_this(), reflect.GetStringView(value(), &field_desc)); + return context.value_factory().CreateBorrowedStringValue( + owner_from_this(), reflect.GetStringView(value(), &field_desc)); } case google::protobuf::FieldDescriptor::TYPE_GROUP: ABSL_FALLTHROUGH_INTENDED; @@ -2305,8 +2305,8 @@ absl::StatusOr> ParsedProtoStructValue::GetSingularField( protobuf_internal::UnwrapStringValueProto(reflect.GetMessage( value(), &field_desc, type()->factory_))); if (absl::holds_alternative(wrapped)) { - return context.value_factory().CreateReferentStringValue( - handle_from_this(), absl::get(wrapped)); + return context.value_factory().CreateBorrowedStringValue( + owner_from_this(), absl::get(wrapped)); } else { ABSL_ASSERT(absl::holds_alternative(wrapped)); return context.value_factory().CreateStringValue( @@ -2327,8 +2327,8 @@ absl::StatusOr> ParsedProtoStructValue::GetSingularField( } case Kind::kStruct: return context.value_factory() - .CreateReferentStructValue( - handle_from_this(), field.type.As(), + .CreateBorrowedStructValue( + owner_from_this(), field.type.As(), &(reflect.GetMessage(value(), &field_desc))); default: ABSL_UNREACHABLE(); @@ -2339,8 +2339,8 @@ absl::StatusOr> ParsedProtoStructValue::GetSingularField( return context.value_factory().CreateBytesValue( reflect.GetCord(value(), &field_desc)); } else { - return context.value_factory().CreateReferentBytesValue( - handle_from_this(), reflect.GetStringView(value(), &field_desc)); + return context.value_factory().CreateBorrowedBytesValue( + owner_from_this(), reflect.GetStringView(value(), &field_desc)); } case google::protobuf::FieldDescriptor::TYPE_ENUM: switch (field.type->kind()) { From d1b1c9810631a16c15b49831651ce6cb06cfa4ec Mon Sep 17 00:00:00 2001 From: tswadell Date: Tue, 2 May 2023 05:24:50 +0000 Subject: [PATCH 234/303] Remove attribute trail tracking in contributions when unknowns are not enabled. PiperOrigin-RevId: 528674423 --- .../flat_expr_builder_comprehensions_test.cc | 37 +++++++++++++++++++ eval/eval/BUILD | 2 - eval/eval/comprehension_step.cc | 16 +++++--- eval/eval/comprehension_step.h | 1 - eval/eval/ident_step.cc | 25 +++++++------ 5 files changed, 62 insertions(+), 19 deletions(-) diff --git a/eval/compiler/flat_expr_builder_comprehensions_test.cc b/eval/compiler/flat_expr_builder_comprehensions_test.cc index aa63a5109..34b312630 100644 --- a/eval/compiler/flat_expr_builder_comprehensions_test.cc +++ b/eval/compiler/flat_expr_builder_comprehensions_test.cc @@ -91,6 +91,43 @@ TEST(FlatExprBuilderComprehensionsTest, MapComp) { test::EqualsCelValue(CelValue::CreateInt64(4))); } +TEST(FlatExprBuilderComprehensionsTest, ListCompWithUnknowns) { + cel::RuntimeOptions options; + options.unknown_processing = UnknownProcessingOptions::kAttributeAndFunction; + FlatExprBuilder builder(options); + + ASSERT_OK_AND_ASSIGN(auto parsed_expr, + parser::Parse("items.exists(i, i < 0)")); + ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder.CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + Activation activation; + activation.set_unknown_attribute_patterns({CelAttributePattern{ + "items", + {CreateCelAttributeQualifierPattern(CelValue::CreateInt64(1))}}}); + ContainerBackedListImpl list_impl = ContainerBackedListImpl({ + CelValue::CreateInt64(1), + // element items[1] is marked unknown, so the computation should produce + // and unknown set. + CelValue::CreateInt64(-1), + CelValue::CreateInt64(2), + }); + activation.InsertValue("items", CelValue::CreateList(&list_impl)); + + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsUnknownSet()) << result.DebugString(); + + const auto& attrs = result.UnknownSetOrDie()->unknown_attributes(); + EXPECT_THAT(attrs, testing::SizeIs(1)); + EXPECT_THAT(attrs.begin()->variable_name(), testing::Eq("items")); + EXPECT_THAT(attrs.begin()->qualifier_path(), testing::SizeIs(1)); + EXPECT_THAT(attrs.begin()->qualifier_path().at(0).GetInt64Key().value(), + testing::Eq(1)); +} + TEST(FlatExprBuilderComprehensionsTest, InvalidComprehensionWithRewrite) { CheckedExpr expr; // The rewrite step which occurs when an identifier gets a more qualified name diff --git a/eval/eval/BUILD b/eval/eval/BUILD index 2fa4699b2..d3ac81b5a 100644 --- a/eval/eval/BUILD +++ b/eval/eval/BUILD @@ -167,7 +167,6 @@ cc_library( "//base:ast_internal", "//eval/internal:errors", "//eval/internal:interop", - "//eval/public:unknown_attribute_set", "//extensions/protobuf:memory_manager", "//internal:status_macros", "@com_google_absl//absl/status", @@ -349,7 +348,6 @@ cc_library( "//internal:status_macros", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", ], ) diff --git a/eval/eval/comprehension_step.cc b/eval/eval/comprehension_step.cc index c9ed960bf..9410cbc42 100644 --- a/eval/eval/comprehension_step.cc +++ b/eval/eval/comprehension_step.cc @@ -93,7 +93,6 @@ absl::Status ComprehensionNextStep::Evaluate(ExecutionFrame* frame) const { return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); } auto state = frame->value_stack().GetSpan(5); - auto attr = frame->value_stack().GetAttributeSpan(5); // Get range from the stack. auto iter_range = state[POS_ITER_RANGE]; @@ -109,7 +108,6 @@ absl::Status ComprehensionNextStep::Evaluate(ExecutionFrame* frame) const { frame->memory_manager(), ""))); return frame->JumpTo(error_jump_offset_); } - AttributeTrail iter_range_attr = attr[POS_ITER_RANGE]; // Get the current index off the stack. const auto& current_index_value = state[POS_CURRENT_INDEX]; @@ -125,6 +123,16 @@ absl::Status ComprehensionNextStep::Evaluate(ExecutionFrame* frame) const { CEL_RETURN_IF_ERROR(frame->PushIterFrame(iter_var_, accu_var_)); } + AttributeTrail iter_range_attr; + AttributeTrail iter_trail; + if (frame->enable_unknowns()) { + auto attr = frame->value_stack().GetAttributeSpan(5); + iter_range_attr = attr[POS_ITER_RANGE]; + iter_trail = + iter_range_attr.Step(cel::AttributeQualifier::OfInt(current_index + 1), + frame->memory_manager()); + } + // Update stack for breaking out of loop or next round. auto loop_step = state[POS_LOOP_STEP]; frame->value_stack().Pop(5); @@ -135,7 +143,7 @@ absl::Status ComprehensionNextStep::Evaluate(ExecutionFrame* frame) const { CEL_RETURN_IF_ERROR(frame->ClearIterVar()); return frame->JumpTo(jump_offset_); } - frame->value_stack().Push(iter_range, iter_range_attr); + frame->value_stack().Push(iter_range, std::move(iter_range_attr)); current_index += 1; CEL_ASSIGN_OR_RETURN(auto current_value, @@ -144,8 +152,6 @@ absl::Status ComprehensionNextStep::Evaluate(ExecutionFrame* frame) const { static_cast(current_index))); frame->value_stack().Push( cel::interop_internal::CreateIntValue(current_index)); - AttributeTrail iter_trail = iter_range_attr.Step( - cel::AttributeQualifier::OfInt(current_index), frame->memory_manager()); frame->value_stack().Push(current_value, iter_trail); CEL_RETURN_IF_ERROR(frame->SetIterVar(current_value, std::move(iter_trail))); return absl::OkStatus(); diff --git a/eval/eval/comprehension_step.h b/eval/eval/comprehension_step.h index dec26e8a4..f0b7a9ff5 100644 --- a/eval/eval/comprehension_step.h +++ b/eval/eval/comprehension_step.h @@ -5,7 +5,6 @@ #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" diff --git a/eval/eval/ident_step.cc b/eval/eval/ident_step.cc index 853da4a8c..4ce459278 100644 --- a/eval/eval/ident_step.cc +++ b/eval/eval/ident_step.cc @@ -53,21 +53,21 @@ absl::StatusOr IdentStep::DoEvaluate( ProtoMemoryManager::CastToProtoArena(frame->memory_manager()); // Special case - comprehension variables mask any activation vars. - if (frame->GetIterVar(name_, &result.value, &result.trail)) { - return result; - } + bool iter_var = frame->GetIterVar(name_, &result.value, &result.trail); // Populate trails if either MissingAttributeError or UnknownPattern // is enabled. - if (frame->enable_missing_attribute_errors() || frame->enable_unknowns()) { - result.trail = AttributeTrail(name_); - } + if (!iter_var) { + if (frame->enable_missing_attribute_errors() || frame->enable_unknowns()) { + result.trail = AttributeTrail(name_); + } - if (frame->enable_missing_attribute_errors() && !name_.empty() && - frame->attribute_utility().CheckForMissingAttribute(result.trail)) { - result.value = cel::interop_internal::CreateErrorValueFromView( - CreateMissingAttributeError(frame->memory_manager(), name_)); - return result; + if (frame->enable_missing_attribute_errors() && !name_.empty() && + frame->attribute_utility().CheckForMissingAttribute(result.trail)) { + result.value = cel::interop_internal::CreateErrorValueFromView( + CreateMissingAttributeError(frame->memory_manager(), name_)); + return result; + } } if (frame->enable_unknowns()) { @@ -78,6 +78,9 @@ absl::StatusOr IdentStep::DoEvaluate( return result; } } + if (iter_var) { + return result; + } CEL_ASSIGN_OR_RETURN(auto value, frame->modern_activation().FindVariable( frame->value_factory(), name_)); From b59a586502eb832100dcbbeca3ca8b9c500ed8d4 Mon Sep 17 00:00:00 2001 From: jcking Date: Tue, 2 May 2023 13:29:34 +0000 Subject: [PATCH 235/303] Eagerly cache JSON-like types and some minor cleanups PiperOrigin-RevId: 528764989 --- base/type_factory.cc | 15 +++++++++++++++ base/type_factory.h | 18 ++++++++++++++++-- base/type_factory_test.cc | 20 ++++++++++++++++++++ base/value_factory.cc | 23 ++--------------------- base/value_factory.h | 19 +++++++++++++++---- 5 files changed, 68 insertions(+), 27 deletions(-) diff --git a/base/type_factory.cc b/base/type_factory.cc index 97d9a44da..5f805a3c8 100644 --- a/base/type_factory.cc +++ b/base/type_factory.cc @@ -30,6 +30,21 @@ using base_internal::ModernMapType; } // namespace +TypeFactory::TypeFactory(MemoryManager& memory_manager) + : memory_manager_(memory_manager) { + json_list_type_ = list_types_ + .insert({GetJsonValueType(), + HandleFactory::Make( + memory_manager_, GetJsonValueType())}) + .first->second; + json_map_type_ = + map_types_ + .insert({std::make_pair(GetStringType(), GetJsonValueType()), + HandleFactory::Make( + memory_manager_, GetStringType(), GetJsonValueType())}) + .first->second; +} + absl::StatusOr> TypeFactory::CreateListType( const Handle& element) { ABSL_DCHECK(element) << "handle must not be empty"; diff --git a/base/type_factory.h b/base/type_factory.h index 9923054a0..5406ada87 100644 --- a/base/type_factory.h +++ b/base/type_factory.h @@ -57,8 +57,7 @@ class TypeFactory final { public: explicit TypeFactory( - MemoryManager& memory_manager ABSL_ATTRIBUTE_LIFETIME_BOUND) - : memory_manager_(memory_manager) {} + MemoryManager& memory_manager ABSL_ATTRIBUTE_LIFETIME_BOUND); TypeFactory(const TypeFactory&) = delete; TypeFactory(TypeFactory&&) = delete; @@ -152,6 +151,18 @@ class TypeFactory final { return UintWrapperType::Get(); } + const Handle& GetJsonValueType() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return GetDynType().As(); + } + + const Handle& GetJsonListType() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return json_list_type_; + } + + const Handle& GetJsonMapType() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return json_map_type_; + } + template EnableIfBaseOfT>> CreateEnumType( Args&&... args) ABSL_ATTRIBUTE_LIFETIME_BOUND { @@ -181,6 +192,9 @@ class TypeFactory final { private: MemoryManager& memory_manager_; + Handle json_list_type_; + Handle json_map_type_; + absl::Mutex list_types_mutex_; // Mapping from list element types to the list type. This allows us to cache // list types and avoid re-creating the same type. diff --git a/base/type_factory_test.cc b/base/type_factory_test.cc index 1dc80d797..4f8e1a3d3 100644 --- a/base/type_factory_test.cc +++ b/base/type_factory_test.cc @@ -41,5 +41,25 @@ TEST(TypeFactory, CreateMapTypeCaches) { EXPECT_EQ(map_type_1.operator->(), map_type_2.operator->()); } +TEST(TypeFactory, JsonValueType) { + TypeFactory type_factory(MemoryManager::Global()); + EXPECT_EQ(type_factory.GetJsonValueType(), type_factory.GetDynType()); +} + +TEST(TypeFactory, JsonListType) { + TypeFactory type_factory(MemoryManager::Global()); + ASSERT_OK_AND_ASSIGN(auto type, + type_factory.CreateListType(type_factory.GetDynType())); + EXPECT_EQ(type, type_factory.GetJsonListType()); +} + +TEST(TypeFactory, JsonMapType) { + TypeFactory type_factory(MemoryManager::Global()); + ASSERT_OK_AND_ASSIGN(auto type, + type_factory.CreateMapType(type_factory.GetStringType(), + type_factory.GetDynType())); + EXPECT_EQ(type, type_factory.GetJsonMapType()); +} + } // namespace } // namespace cel diff --git a/base/value_factory.cc b/base/value_factory.cc index fd7bf478d..c05c7dd51 100644 --- a/base/value_factory.cc +++ b/base/value_factory.cc @@ -36,7 +36,6 @@ using base_internal::InlinedCordBytesValue; using base_internal::InlinedCordStringValue; using base_internal::InlinedStringViewBytesValue; using base_internal::InlinedStringViewStringValue; -using base_internal::ModernTypeValue; using base_internal::StringBytesValue; using base_internal::StringStringValue; @@ -46,10 +45,6 @@ Handle NullValue::Get(ValueFactory& value_factory) { return value_factory.GetNullValue(); } -Handle ValueFactory::GetNullValue() { - return HandleFactory::Make(); -} - Handle ValueFactory::CreateErrorValue(absl::Status status) { if (ABSL_PREDICT_FALSE(status.ok())) { status = absl::UnknownError( @@ -180,17 +175,13 @@ absl::StatusOr> ValueFactory::CreateStringValue( absl::StatusOr> ValueFactory::CreateDurationValue( absl::Duration value) { CEL_RETURN_IF_ERROR(internal::ValidateDuration(value)); - return HandleFactory::Make(value); + return CreateUncheckedDurationValue(value); } absl::StatusOr> ValueFactory::CreateTimestampValue( absl::Time value) { CEL_RETURN_IF_ERROR(internal::ValidateTimestamp(value)); - return HandleFactory::Make(value); -} - -Handle ValueFactory::CreateTypeValue(const Handle& value) { - return HandleFactory::Make(value); + return CreateUncheckedTimestampValue(value); } Handle ValueFactory::CreateUnknownValue( @@ -205,16 +196,6 @@ absl::StatusOr> ValueFactory::CreateBytesValueFromView( return HandleFactory::Make(value); } -Handle ValueFactory::GetEmptyBytesValue() { - return HandleFactory::Make( - absl::string_view()); -} - -Handle ValueFactory::GetEmptyStringValue() { - return HandleFactory::Make( - absl::string_view()); -} - absl::StatusOr> ValueFactory::CreateStringValueFromView( absl::string_view value) { return HandleFactory::Make(value); diff --git a/base/value_factory.h b/base/value_factory.h index 5f213bbe0..2f6c0d9f0 100644 --- a/base/value_factory.h +++ b/base/value_factory.h @@ -107,7 +107,9 @@ class ValueFactory final { TypeManager& type_manager() const { return type_manager_; } - Handle GetNullValue() ABSL_ATTRIBUTE_LIFETIME_BOUND; + Handle GetNullValue() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return base_internal::HandleFactory::Make(); + } Handle CreateErrorValue(absl::Status status) ABSL_ATTRIBUTE_LIFETIME_BOUND; @@ -410,7 +412,10 @@ class ValueFactory final { } Handle CreateTypeValue(const Handle& value) - ABSL_ATTRIBUTE_LIFETIME_BOUND; + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return base_internal::HandleFactory::Make< + base_internal::ModernTypeValue>(value); + } Handle CreateUnknownValue() ABSL_ATTRIBUTE_LIFETIME_BOUND { return CreateUnknownValue(AttributeSet(), FunctionResultSet()); @@ -438,12 +443,18 @@ class ValueFactory final { friend class BytesValue; friend class StringValue; - Handle GetEmptyBytesValue() ABSL_ATTRIBUTE_LIFETIME_BOUND; + Handle GetEmptyBytesValue() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return base_internal::HandleFactory::Make< + base_internal::InlinedStringViewBytesValue>(absl::string_view()); + } absl::StatusOr> CreateBytesValueFromView( absl::string_view value) ABSL_ATTRIBUTE_LIFETIME_BOUND; - Handle GetEmptyStringValue() ABSL_ATTRIBUTE_LIFETIME_BOUND; + Handle GetEmptyStringValue() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return base_internal::HandleFactory::Make< + base_internal::InlinedStringViewStringValue>(absl::string_view()); + } absl::StatusOr> CreateStringValueFromView( absl::string_view value) ABSL_ATTRIBUTE_LIFETIME_BOUND; From 1dcf06e64048684a8cf44980921e9e75da58c79c Mon Sep 17 00:00:00 2001 From: tswadell Date: Tue, 2 May 2023 16:10:35 +0000 Subject: [PATCH 236/303] Internal tooling change PiperOrigin-RevId: 528801107 --- eval/public/ast_rewrite.cc | 2 + eval/public/ast_rewrite.h | 6 + .../macro_nested_macro_call.textproto | 257 ++++++++++++++++++ 3 files changed, 265 insertions(+) create mode 100644 tools/testdata/macro_nested_macro_call.textproto diff --git a/eval/public/ast_rewrite.cc b/eval/public/ast_rewrite.cc index 2553e7b88..c509a3a80 100644 --- a/eval/public/ast_rewrite.cc +++ b/eval/public/ast_rewrite.cc @@ -193,6 +193,8 @@ struct PostVisitor { visitor->PostVisitComprehension(&expr->comprehension_expr(), expr, &position); break; + case Expr::EXPR_KIND_NOT_SET: + break; default: LOG(ERROR) << "Unsupported Expr kind: " << expr->expr_kind_case(); } diff --git a/eval/public/ast_rewrite.h b/eval/public/ast_rewrite.h index 9c4eacd8c..c21cb86bc 100644 --- a/eval/public/ast_rewrite.h +++ b/eval/public/ast_rewrite.h @@ -60,6 +60,12 @@ class AstRewriterBase : public AstRewriter { public: ~AstRewriterBase() override {} + void PreVisitExpr(const google::api::expr::v1alpha1::Expr*, + const SourcePosition*) override {} + + void PostVisitExpr(const google::api::expr::v1alpha1::Expr*, + const SourcePosition*) override {} + void PostVisitConst(const google::api::expr::v1alpha1::Constant*, const google::api::expr::v1alpha1::Expr*, const SourcePosition*) override {} diff --git a/tools/testdata/macro_nested_macro_call.textproto b/tools/testdata/macro_nested_macro_call.textproto new file mode 100644 index 000000000..11bdf7f6f --- /dev/null +++ b/tools/testdata/macro_nested_macro_call.textproto @@ -0,0 +1,257 @@ +# proto-file: google3/google/api/expr/checked.proto +# proto-message: CheckedExpr +# math.least(has(msg.old_field) ? msg.old_field : 0, 1) +reference_map: { + key: 4 + value: { + name: "msg" + } +} +reference_map: { + key: 7 + value: { + overload_id: "conditional" + } +} +reference_map: { + key: 8 + value: { + name: "msg" + } +} +reference_map: { + key: 12 + value: { + overload_id: "math_@min_int_int" + } +} +type_map: { + key: 4 + value: { + map_type: { + key_type: { + primitive: STRING + } + value_type: { + primitive: INT64 + } + } + } +} +type_map: { + key: 6 + value: { + primitive: BOOL + } +} +type_map: { + key: 7 + value: { + primitive: INT64 + } +} +type_map: { + key: 8 + value: { + map_type: { + key_type: { + primitive: STRING + } + value_type: { + primitive: INT64 + } + } + } +} +type_map: { + key: 9 + value: { + primitive: INT64 + } +} +type_map: { + key: 10 + value: { + primitive: INT64 + } +} +type_map: { + key: 11 + value: { + primitive: INT64 + } +} +type_map: { + key: 12 + value: { + primitive: INT64 + } +} +source_info: { + location: "" + line_offsets: 54 + positions: { + key: 1 + value: 0 + } + positions: { + key: 2 + value: 10 + } + positions: { + key: 3 + value: 14 + } + positions: { + key: 4 + value: 15 + } + positions: { + key: 5 + value: 18 + } + positions: { + key: 6 + value: 14 + } + positions: { + key: 7 + value: 30 + } + positions: { + key: 8 + value: 32 + } + positions: { + key: 9 + value: 35 + } + positions: { + key: 10 + value: 48 + } + positions: { + key: 11 + value: 51 + } + positions: { + key: 12 + value: 10 + } + macro_calls: { + key: 6 + value: { + call_expr: { + function: "has" + args: { + id: 5 + select_expr: { + operand: { + id: 4 + ident_expr: { + name: "msg" + } + } + field: "old_field" + } + } + } + } + } + macro_calls: { + key: 12 + value: { + call_expr: { + target: { + id: 1 + ident_expr: { + name: "math" + } + } + function: "least" + args: { + id: 7 + call_expr: { + function: "_?_:_" + args: { + id: 6 + } + args: { + id: 9 + select_expr: { + operand: { + id: 8 + ident_expr: { + name: "msg" + } + } + field: "old_field" + } + } + args: { + id: 10 + const_expr: { + int64_value: 0 + } + } + } + } + args: { + id: 11 + const_expr: { + int64_value: 1 + } + } + } + } + } +} +expr: { + id: 12 + call_expr: { + function: "math.@min" + args: { + id: 7 + call_expr: { + function: "_?_:_" + args: { + id: 6 + select_expr: { + operand: { + id: 4 + ident_expr: { + name: "msg" + } + } + field: "old_field" + test_only: true + } + } + args: { + id: 9 + select_expr: { + operand: { + id: 8 + ident_expr: { + name: "msg" + } + } + field: "old_field" + } + } + args: { + id: 10 + const_expr: { + int64_value: 0 + } + } + } + } + args: { + id: 11 + const_expr: { + int64_value: 1 + } + } + } +} From cd59c9cda351789b9ccc3f259bb0f4d6962d3e34 Mon Sep 17 00:00:00 2001 From: jdtatum Date: Tue, 2 May 2023 18:08:08 +0000 Subject: [PATCH 237/303] Refactor core evaluation step to a member function of an Execution frame instance. This allows for reuse in contexts where the evaluator doesn't need all of the state management required for FlatExpression instances. PiperOrigin-RevId: 528834986 --- eval/eval/BUILD | 3 ++ eval/eval/evaluator_core.cc | 70 ++++++++++++++++++++----------------- eval/eval/evaluator_core.h | 6 ++++ 3 files changed, 46 insertions(+), 33 deletions(-) diff --git a/eval/eval/BUILD b/eval/eval/BUILD index d3ac81b5a..4880a208a 100644 --- a/eval/eval/BUILD +++ b/eval/eval/BUILD @@ -19,6 +19,7 @@ cc_library( ":attribute_utility", ":evaluator_stack", "//base:ast_internal", + "//base:handle", "//base:memory_manager", "//base:type", "//base:value", @@ -32,9 +33,11 @@ cc_library( "//eval/public:unknown_attribute_set", "//extensions/protobuf:memory_manager", "//internal:casts", + "//internal:status_macros", "//runtime:activation_interface", "//runtime:runtime_options", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", diff --git a/eval/eval/evaluator_core.cc b/eval/eval/evaluator_core.cc index af6a341be..a35986c3c 100644 --- a/eval/eval/evaluator_core.cc +++ b/eval/eval/evaluator_core.cc @@ -5,6 +5,7 @@ #include #include +#include "absl/functional/function_ref.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" @@ -13,9 +14,11 @@ #include "base/value_factory.h" #include "eval/eval/attribute_trail.h" #include "eval/internal/interop.h" +#include "eval/public/cel_expression.h" #include "eval/public/cel_value.h" #include "extensions/protobuf/memory_manager.h" #include "internal/casts.h" +#include "internal/status_macros.h" namespace google::api::expr::runtime { @@ -148,53 +151,54 @@ absl::StatusOr CelExpressionFlatImpl::Evaluate( return Trace(activation, state, CelEvaluationListener()); } -absl::StatusOr CelExpressionFlatImpl::Trace( - const BaseActivation& activation, CelEvaluationState* _state, - CelEvaluationListener callback) const { - auto state = - ::cel::internal::down_cast(_state); - state->Reset(); - - ExecutionFrame frame(path_, activation, &type_registry_, options_, state); - - EvaluatorStack* stack = &frame.value_stack(); - size_t initial_stack_size = stack->size(); +absl::StatusOr> ExecutionFrame::Evaluate( + const CelEvaluationListener& listener) { + size_t initial_stack_size = value_stack().size(); const ExpressionStep* expr; - while ((expr = frame.Next()) != nullptr) { - auto status = expr->Evaluate(&frame); - if (!status.ok()) { - return status; - } - if (!callback) { - continue; - } - if (!expr->ComesFromAst()) { - // This step was added during compilation (e.g. Int64ConstImpl). + google::protobuf::Arena* arena = cel::extensions::ProtoMemoryManager::CastToProtoArena( + value_factory().memory_manager()); + while ((expr = Next()) != nullptr) { + CEL_RETURN_IF_ERROR(expr->Evaluate(this)); + + if (!listener || + // This step was added during compilation (e.g. Int64ConstImpl). + !expr->ComesFromAst()) { continue; } - if (stack->empty()) { + if (value_stack().empty()) { LOG(ERROR) << "Stack is empty after a ExpressionStep.Evaluate. " "Try to disable short-circuiting."; continue; } - auto status2 = - callback(expr->id(), + CEL_RETURN_IF_ERROR( + listener(expr->id(), cel::interop_internal::ModernValueToLegacyValueOrDie( - state->arena(), stack->Peek()), - state->arena()); - if (!status2.ok()) { - return status2; - } + arena, value_stack().Peek()), + arena)); } - size_t final_stack_size = stack->size(); - if (initial_stack_size + 1 != final_stack_size || final_stack_size == 0) { + size_t final_stack_size = value_stack().size(); + if (final_stack_size != initial_stack_size + 1 || final_stack_size == 0) { return absl::Status(absl::StatusCode::kInternal, "Stack error during evaluation"); } - auto value = stack->Peek(); - stack->Pop(1); + cel::Handle value = value_stack().Peek(); + value_stack().Pop(1); + return value; +} + +absl::StatusOr CelExpressionFlatImpl::Trace( + const BaseActivation& activation, CelEvaluationState* _state, + CelEvaluationListener callback) const { + auto state = + ::cel::internal::down_cast(_state); + state->Reset(); + + ExecutionFrame frame(path_, activation, &type_registry_, options_, state); + + CEL_ASSIGN_OR_RETURN(cel::Handle value, frame.Evaluate(callback)); + return cel::interop_internal::ModernValueToLegacyValueOrDie(state->arena(), value); } diff --git a/eval/eval/evaluator_core.h b/eval/eval/evaluator_core.h index f5f94d669..43e812f2d 100644 --- a/eval/eval/evaluator_core.h +++ b/eval/eval/evaluator_core.h @@ -22,8 +22,10 @@ #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "base/ast_internal.h" +#include "base/handle.h" #include "base/memory_manager.h" #include "base/type_manager.h" +#include "base/value.h" #include "base/value_factory.h" #include "eval/eval/attribute_trail.h" #include "eval/eval/attribute_utility.h" @@ -152,6 +154,10 @@ class ExecutionFrame { // Returns next expression to evaluate. const ExpressionStep* Next(); + // Evaluate the execution frame to completion. + absl::StatusOr> Evaluate( + const CelEvaluationListener& listener); + // Intended for use only in conditionals. absl::Status JumpTo(int offset) { int new_pc = static_cast(pc_) + offset; From 4577ead9c7fef6b97d110f9086f262db4f2fcab4 Mon Sep 17 00:00:00 2001 From: jdtatum Date: Tue, 2 May 2023 21:16:54 +0000 Subject: [PATCH 238/303] Add flat expr builder extension API to support program updates while planning. Introduce updated const folding implementation that uses the new api. Add conformance tests for the two variants of const folding. PiperOrigin-RevId: 528886562 --- conformance/BUILD | 3 +- conformance/server.cc | 18 +- eval/compiler/BUILD | 55 ++++ eval/compiler/constant_folding.cc | 144 +++++++++- eval/compiler/constant_folding.h | 33 ++- eval/compiler/constant_folding_test.cc | 246 ++++++++++++++++ eval/compiler/flat_expr_builder.cc | 84 +++++- eval/compiler/flat_expr_builder.h | 5 +- eval/compiler/flat_expr_builder_extensions.cc | 115 ++++++++ eval/compiler/flat_expr_builder_extensions.h | 43 ++- .../flat_expr_builder_extensions_test.cc | 266 ++++++++++++++++++ eval/compiler/flat_expr_builder_test.cc | 193 ++++++++++++- eval/public/cel_options.h | 1 + .../portable_cel_expr_builder_factory.cc | 3 +- eval/tests/BUILD | 34 ++- eval/tests/benchmark_test.cc | 141 +++++++--- .../expression_builder_benchmark_test.cc | 94 +++---- 17 files changed, 1337 insertions(+), 141 deletions(-) create mode 100644 eval/compiler/flat_expr_builder_extensions.cc create mode 100644 eval/compiler/flat_expr_builder_extensions_test.cc diff --git a/conformance/BUILD b/conformance/BUILD index 33c9b7133..fdd29bc66 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -67,7 +67,7 @@ cc_binary( [ sh_test( - name = "simple" + arg, + name = "simple" + arg.replace("--", "_"), srcs = ["@com_google_cel_spec//tests:conftest.sh"], args = [ "$(location @com_google_cel_spec//tests/simple:simple_test)", @@ -105,6 +105,7 @@ cc_binary( for arg in [ "", "--opt", + "--updated_opt", ] ] diff --git a/conformance/server.cc b/conformance/server.cc index 969d9bc4f..3a1f67b8f 100644 --- a/conformance/server.cc +++ b/conformance/server.cc @@ -35,6 +35,8 @@ using ::google::protobuf::util::JsonStringToMessage; using ::google::protobuf::util::MessageToJsonString; ABSL_FLAG(bool, opt, false, "Enable optimizations (constant folding)"); +ABSL_FLAG(bool, updated_opt, false, + "Enable optimizations (constant folding updated)"); ABSL_FLAG(bool, base64_encode, false, "Enable base64 encoding in pipe mode."); namespace google::api::expr::runtime { @@ -150,10 +152,10 @@ class ConformanceServiceImpl { const google::api::expr::test::v1::proto3::TestAllTypes* proto3_tests_; }; -absl::Status Base64DecodeToMessage(absl::string_view b64Data, +absl::Status Base64DecodeToMessage(absl::string_view b64_data, google::protobuf::Message* out) { std::string data; - if (!absl::Base64Unescape(b64Data, &data)) { + if (!absl::Base64Unescape(b64_data, &data)) { return absl::InvalidArgumentError("invalid base64"); } if (!out->ParseFromString(data)) { @@ -197,20 +199,23 @@ class PipeCodec { bool base64_encoded_; }; -int RunServer(bool optimize, bool base64Encoded) { +int RunServer(bool optimize, bool base64_encoded, bool updated_optimize) { google::protobuf::Arena arena; - PipeCodec pipe_codec(base64Encoded); + PipeCodec pipe_codec(base64_encoded); InterpreterOptions options; options.enable_qualified_type_identifiers = true; options.enable_timestamp_duration_overflow_errors = true; options.enable_heterogeneous_equality = true; options.enable_empty_wrapper_null_unboxing = true; - if (optimize) { + if (optimize || updated_optimize) { std::cerr << "Enabling optimizations" << std::endl; options.constant_folding = true; options.constant_arena = &arena; } + if (updated_optimize) { + options.enable_updated_constant_folding = true; + } std::unique_ptr builder = CreateCelExpressionBuilder(options); @@ -282,5 +287,6 @@ int RunServer(bool optimize, bool base64Encoded) { int main(int argc, char** argv) { absl::ParseCommandLine(argc, argv); return google::api::expr::runtime::RunServer( - absl::GetFlag(FLAGS_opt), absl::GetFlag(FLAGS_base64_encode)); + absl::GetFlag(FLAGS_opt), absl::GetFlag(FLAGS_base64_encode), + absl::GetFlag(FLAGS_updated_opt)); } diff --git a/eval/compiler/BUILD b/eval/compiler/BUILD index b81969900..61467f95b 100644 --- a/eval/compiler/BUILD +++ b/eval/compiler/BUILD @@ -8,11 +8,38 @@ exports_files(["LICENSE"]) cc_library( name = "flat_expr_builder_extensions", + srcs = ["flat_expr_builder_extensions.cc"], hdrs = ["flat_expr_builder_extensions.h"], deps = [ ":resolver", "//base:ast", + "//base:ast_internal", + "//eval/eval:evaluator_core", + "//eval/eval:expression_build_warning", + "//eval/public:cel_type_registry", + "//runtime:runtime_options", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "flat_expr_builder_extensions_test", + srcs = ["flat_expr_builder_extensions_test.cc"], + deps = [ + ":flat_expr_builder_extensions", + ":resolver", + "//base:ast_internal", + "//eval/eval:const_value_step", + "//eval/eval:evaluator_core", "//eval/eval:expression_build_warning", + "//eval/public:cel_type_registry", + "//internal:testing", + "//runtime:function_registry", + "//runtime:runtime_options", "@com_google_absl//absl/status", ], ) @@ -84,6 +111,8 @@ cc_test( deps = [ ":flat_expr_builder", ":qualified_reference_resolver", + "//base:function", + "//base:function_descriptor", "//eval/eval:expression_build_warning", "//eval/public:activation", "//eval/public:builtin_func_registrar", @@ -92,8 +121,10 @@ cc_test( "//eval/public:cel_expr_builder_factory", "//eval/public:cel_expression", "//eval/public:cel_function_adapter", + "//eval/public:cel_function_registry", "//eval/public:cel_options", "//eval/public:cel_value", + "//eval/public:portable_cel_function_adapter", "//eval/public:unknown_attribute_set", "//eval/public:unknown_set", "//eval/public/containers:container_backed_map_impl", @@ -106,6 +137,7 @@ cc_test( "//internal:testing", "//parser", "//runtime:runtime_options", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -155,12 +187,20 @@ cc_library( "constant_folding.h", ], deps = [ + ":flat_expr_builder_extensions", + ":resolver", "//base:ast_internal", "//base:function", + "//base:handle", + "//base:kind", "//base:value", + "//eval/eval:const_value_step", + "//eval/eval:evaluator_core", "//eval/internal:errors", "//eval/internal:interop", + "//eval/public:activation", "//eval/public:cel_builtins", + "//eval/public:cel_expression", "//eval/public:cel_value", "//eval/public/containers:container_backed_list_impl", "//extensions/protobuf:memory_manager", @@ -168,8 +208,12 @@ cc_library( "//runtime:function_overload_reference", "//runtime:function_registry", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_absl//absl/types:variant", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_protobuf//:protobuf", ], ) @@ -180,15 +224,26 @@ cc_test( ], deps = [ ":constant_folding", + ":flat_expr_builder_extensions", + ":resolver", + "//base:ast_internal", "//base:type", "//base:value", + "//base/internal:ast_impl", + "//eval/eval:const_value_step", + "//eval/eval:evaluator_core", + "//eval/eval:expression_build_warning", "//eval/public:builtin_func_registrar", "//eval/public:cel_function_registry", + "//eval/public:cel_type_registry", "//eval/testutil:test_message_cc_proto", "//extensions/protobuf:ast_converters", "//extensions/protobuf:memory_manager", "//internal:status_macros", "//internal:testing", + "//parser", + "//runtime:function_registry", + "//runtime:runtime_options", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], diff --git a/eval/compiler/constant_folding.cc b/eval/compiler/constant_folding.cc index 0a0499dd5..2f30084c1 100644 --- a/eval/compiler/constant_folding.cc +++ b/eval/compiler/constant_folding.cc @@ -5,13 +5,28 @@ #include #include +#include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "absl/types/variant.h" #include "base/ast_internal.h" #include "base/function.h" +#include "base/handle.h" +#include "base/kind.h" +#include "base/value.h" +#include "base/values/bytes_value.h" #include "base/values/error_value.h" +#include "base/values/string_value.h" +#include "base/values/unknown_value.h" +#include "eval/compiler/flat_expr_builder_extensions.h" +#include "eval/compiler/resolver.h" +#include "eval/eval/const_value_step.h" +#include "eval/eval/evaluator_core.h" #include "eval/internal/errors.h" #include "eval/internal/interop.h" +#include "eval/public/activation.h" #include "eval/public/cel_builtins.h" +#include "eval/public/cel_expression.h" #include "eval/public/cel_value.h" #include "eval/public/containers/container_backed_list_impl.h" #include "extensions/protobuf/memory_manager.h" @@ -27,8 +42,18 @@ using ::cel::interop_internal::CreateErrorValueFromView; using ::cel::interop_internal::CreateLegacyListValue; using ::cel::interop_internal::CreateNoMatchingOverloadError; using ::cel::interop_internal::ModernValueToLegacyValueOrDie; +using ::google::api::expr::runtime::CelEvaluationListener; using ::google::api::expr::runtime::CelValue; using ::google::api::expr::runtime::ContainerBackedListImpl; +using ::google::api::expr::runtime::ExecutionFrame; +using ::google::api::expr::runtime::ExecutionPath; +using ::google::api::expr::runtime::ExecutionPathView; +using ::google::api::expr::runtime::PlannerContext; +using ::google::api::expr::runtime::Resolver; +using ::google::api::expr::runtime::builtin::kAnd; +using ::google::api::expr::runtime::builtin::kOr; +using ::google::api::expr::runtime::builtin::kTernary; + using ::google::protobuf::Arena; Handle CreateLegacyListBackedHandle( @@ -48,34 +73,33 @@ struct MakeConstantArenaSafeVisitor { // non-arena based cel::MemoryManager. google::protobuf::Arena* arena; - cel::Handle operator()( - const cel::ast::internal::NullValue& value) { + Handle operator()(const cel::ast::internal::NullValue& value) { return cel::interop_internal::CreateNullValue(); } - cel::Handle operator()(bool value) { + Handle operator()(bool value) { return cel::interop_internal::CreateBoolValue(value); } - cel::Handle operator()(int64_t value) { + Handle operator()(int64_t value) { return cel::interop_internal::CreateIntValue(value); } - cel::Handle operator()(uint64_t value) { + Handle operator()(uint64_t value) { return cel::interop_internal::CreateUintValue(value); } - cel::Handle operator()(double value) { + Handle operator()(double value) { return cel::interop_internal::CreateDoubleValue(value); } - cel::Handle operator()(const std::string& value) { + Handle operator()(const std::string& value) { const auto* arena_copy = Arena::Create(arena, value); return cel::interop_internal::CreateStringValueFromView(*arena_copy); } - cel::Handle operator()(const cel::ast::internal::Bytes& value) { + Handle operator()(const cel::ast::internal::Bytes& value) { const auto* arena_copy = Arena::Create(arena, value.bytes); return cel::interop_internal::CreateBytesValueFromView(*arena_copy); } - cel::Handle operator()(const absl::Duration duration) { + Handle operator()(const absl::Duration duration) { return cel::interop_internal::CreateDurationValue(duration); } - cel::Handle operator()(const absl::Time timestamp) { + Handle operator()(const absl::Time timestamp) { return cel::interop_internal::CreateTimestampValue(timestamp); } }; @@ -137,6 +161,10 @@ class ConstantFoldingTransform { } bool operator()(const Ident& ident) { + // TODO(issues/5): this could be updated to use the rewrite visitor + // to make changes in-place instead of manually copy. This would avoid + // having to understand how to copy all of the information in the original + // AST. out_.mutable_ident_expr().set_name(expr_.ident_expr().name()); return false; } @@ -249,6 +277,7 @@ class ConstantFoldingTransform { bool all_constant = true; for (int i = 0; i < list_size; i++) { auto& element = list_expr.mutable_elements().emplace_back(); + // TODO(issues/5): Add support for CEL optional. all_constant = transform_.Transform(expr_.list_expr().elements()[i], element) && all_constant; @@ -287,6 +316,7 @@ class ConstantFoldingTransform { auto& new_entry = struct_expr.mutable_entries().emplace_back(); new_entry.set_id(entry.id()); struct { + // TODO(issues/5): Add support for CEL optional. ConstantFoldingTransform& transform; const CreateStruct::Entry& entry; CreateStruct::Entry& new_entry; @@ -358,6 +388,100 @@ bool ConstantFoldingTransform::Transform(const Expr& expr, Expr& out_) { } // namespace +absl::Status ConstantFoldingExtension::OnPreVisit(PlannerContext& context, + const Expr& node) { + struct IsConstVisitor { + IsConst operator()(const Constant&) { return IsConst::kConditional; } + IsConst operator()(const Ident&) { return IsConst::kNonConst; } + IsConst operator()(const Comprehension&) { + // Not yet supported, need to identify whether range and + // iter vars are compatible with const folding. + return IsConst::kNonConst; + } + IsConst operator()(const CreateStruct&) { + // Not yet supported but should be possible in the future. + return IsConst::kNonConst; + } + IsConst operator()(const CreateList& create_list) { + if (create_list.elements().empty()) { + // TODO(issues/5): Don't fold for empty list to allow comprehension + // list append optimization. + return IsConst::kNonConst; + } + return IsConst::kConditional; + } + + IsConst operator()(const Select&) { return IsConst::kConditional; } + + IsConst operator()(absl::monostate) { return IsConst::kNonConst; } + + IsConst operator()(const Call& call) { + // Shortcircuiting operators not yet supported. + if (call.function() == kAnd || call.function() == kOr || + call.function() == kTernary) { + return IsConst::kNonConst; + } + + int arg_len = call.args().size() + (call.has_target() ? 1 : 0); + std::vector arg_matcher(arg_len, cel::Kind::kAny); + // Check for any lazy overloads (activation dependant) + if (!resolver + .FindLazyOverloads(call.function(), call.has_target(), + arg_matcher) + .empty()) { + return IsConst::kNonConst; + } + + return IsConst::kConditional; + } + + const Resolver& resolver; + }; + + IsConst is_const = + absl::visit(IsConstVisitor{context.resolver()}, node.expr_kind()); + is_const_.push_back(is_const); + + return absl::OkStatus(); +} + +absl::Status ConstantFoldingExtension::OnPostVisit(PlannerContext& context, + const Expr& node) { + IsConst is_const = is_const_.back(); + is_const_.pop_back(); + + if (is_const == IsConst::kNonConst) { + // update parent + if (!is_const_.empty()) { + is_const_.back() = IsConst::kNonConst; + } + return absl::OkStatus(); + } + + // copy string to arena if backed by the original program. + Handle value; + if (node.has_const_expr()) { + value = absl::visit(MakeConstantArenaSafeVisitor{arena_}, + node.const_expr().constant_kind()); + } else { + ExecutionPathView subplan = context.GetSubplan(node); + ExecutionFrame frame(subplan, empty_, &context.type_registry(), + context.options(), &state_); + state_.Reset(); + CEL_ASSIGN_OR_RETURN(value, frame.Evaluate(null_listener_)); + if (value->Is()) { + return absl::OkStatus(); + } + } + + ExecutionPath new_plan; + CEL_ASSIGN_OR_RETURN(new_plan.emplace_back(), + google::api::expr::runtime::CreateConstValueStep( + std::move(value), node.id(), false)); + + return context.ReplaceSubplan(node, std::move(new_plan)); +} + void FoldConstants( const Expr& ast, const FunctionRegistry& registry, google::protobuf::Arena* arena, absl::flat_hash_map>& constant_idents, diff --git a/eval/compiler/constant_folding.h b/eval/compiler/constant_folding.h index bc1010b8a..65dbbf224 100644 --- a/eval/compiler/constant_folding.h +++ b/eval/compiler/constant_folding.h @@ -2,11 +2,18 @@ #define THIRD_PARTY_CEL_CPP_EVAL_COMPILER_CONSTANT_FOLDING_H_ #include +#include -#include "google/api/expr/v1alpha1/syntax.pb.h" #include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" #include "base/ast_internal.h" +#include "base/value.h" +#include "eval/compiler/flat_expr_builder_extensions.h" +#include "eval/eval/evaluator_core.h" +#include "eval/public/activation.h" +#include "eval/public/cel_expression.h" #include "runtime/function_registry.h" +#include "google/protobuf/arena.h" namespace cel::ast::internal { @@ -18,6 +25,30 @@ void FoldConstants( absl::flat_hash_map>& constant_idents, Expr& out_ast); +class ConstantFoldingExtension { + public: + ConstantFoldingExtension(int stack_limit, google::protobuf::Arena* arena) + : arena_(arena), state_(stack_limit, arena) {} + + absl::Status OnPreVisit(google::api::expr::runtime::PlannerContext& context, + const Expr& node); + absl::Status OnPostVisit(google::api::expr::runtime::PlannerContext& context, + const Expr& node); + + private: + enum class IsConst { + kConditional, + kNonConst, + }; + + google::protobuf::Arena* arena_; + google::api::expr::runtime::Activation empty_; + google::api::expr::runtime::CelEvaluationListener null_listener_; + google::api::expr::runtime::CelExpressionFlatEvaluationState state_; + + std::vector is_const_; +}; + } // namespace cel::ast::internal #endif // THIRD_PARTY_CEL_CPP_EVAL_COMPILER_CONSTANT_FOLDING_H_ diff --git a/eval/compiler/constant_folding_test.cc b/eval/compiler/constant_folding_test.cc index c04651653..c03d6abd9 100644 --- a/eval/compiler/constant_folding_test.cc +++ b/eval/compiler/constant_folding_test.cc @@ -1,9 +1,12 @@ #include "eval/compiler/constant_folding.h" +#include #include #include "google/api/expr/v1alpha1/syntax.pb.h" #include "google/protobuf/text_format.h" +#include "base/ast_internal.h" +#include "base/internal/ast_impl.h" #include "base/type_factory.h" #include "base/type_manager.h" #include "base/value_factory.h" @@ -12,22 +15,43 @@ #include "base/values/int_value.h" #include "base/values/list_value.h" #include "base/values/string_value.h" +#include "eval/compiler/flat_expr_builder_extensions.h" +#include "eval/compiler/resolver.h" +#include "eval/eval/const_value_step.h" +#include "eval/eval/evaluator_core.h" +#include "eval/eval/expression_build_warning.h" #include "eval/public/builtin_func_registrar.h" #include "eval/public/cel_function_registry.h" +#include "eval/public/cel_type_registry.h" #include "eval/testutil/test_message.pb.h" #include "extensions/protobuf/ast_converters.h" #include "extensions/protobuf/memory_manager.h" #include "internal/status_macros.h" #include "internal/testing.h" +#include "parser/parser.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/text_format.h" namespace cel::ast::internal { namespace { +using ::cel::ast::internal::Constant; +using ::cel::ast::internal::ConstantKind; using ::cel::extensions::ProtoMemoryManager; using ::cel::extensions::internal::ConvertProtoExprToNative; +using ::google::api::expr::v1alpha1::ParsedExpr; +using ::google::api::expr::parser::Parse; +using ::google::api::expr::runtime::BuilderWarnings; using ::google::api::expr::runtime::CelFunctionRegistry; +using ::google::api::expr::runtime::CelTypeRegistry; +using ::google::api::expr::runtime::CreateConstValueStep; +using ::google::api::expr::runtime::ExecutionPath; +using ::google::api::expr::runtime::PlannerContext; +using ::google::api::expr::runtime::Resolver; using ::google::protobuf::Arena; +using testing::SizeIs; class ConstantFoldingTestWithValueFactory : public testing::Test { public: @@ -502,6 +526,228 @@ TEST(ConstantFoldingTest, MapComprehension) { EXPECT_TRUE(idents["$v5"]->Is()); } +class UpdatedConstantFoldingTest : public testing::Test { + public: + UpdatedConstantFoldingTest() + : resolver_("", function_registry_, &type_registry_) {} + + protected: + cel::FunctionRegistry function_registry_; + CelTypeRegistry type_registry_; + cel::RuntimeOptions options_; + BuilderWarnings builder_warnings_; + Resolver resolver_; +}; + +absl::StatusOr> ParseFromCel( + absl::string_view expression) { + CEL_ASSIGN_OR_RETURN(ParsedExpr expr, Parse(expression)); + return cel::extensions::CreateAstFromParsedExpr(expr); +} + +// While CEL doesn't provide execution order guarantees per se, short circuiting +// operators are treated specially to evaluate to user expectations. +// +// These behaviors aren't easily observable since the flat expression doesn't +// expose any details about the program after building, so a lot of setup is +// needed to simulate what the expression builder does. +TEST_F(UpdatedConstantFoldingTest, SkipsTernary) { + // Arrange + ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, + ParseFromCel("true ? true : false")); + AstImpl& ast_impl = AstImpl::CastFromPublicAst(*ast); + + const Expr& call = ast_impl.root_expr(); + const Expr& condition = call.call_expr().args()[0]; + const Expr& true_branch = call.call_expr().args()[1]; + const Expr& false_branch = call.call_expr().args()[2]; + + PlannerContext::ProgramTree tree; + PlannerContext::ProgramInfo& call_info = tree[&call]; + call_info.range_start = 0; + call_info.range_len = 4; + call_info.children = {&condition, &true_branch, &false_branch}; + + PlannerContext::ProgramInfo& condition_info = tree[&condition]; + condition_info.range_start = 0; + condition_info.range_len = 1; + condition_info.parent = &call; + + PlannerContext::ProgramInfo& true_branch_info = tree[&true_branch]; + true_branch_info.range_start = 1; + true_branch_info.range_len = 1; + true_branch_info.parent = &call; + + PlannerContext::ProgramInfo& false_branch_info = tree[&false_branch]; + false_branch_info.range_start = 2; + false_branch_info.range_len = 1; + false_branch_info.parent = &call; + + // Mock execution path that has placeholders for the non-shortcircuiting + // version of ternary. + ExecutionPath path; + + ASSERT_OK_AND_ASSIGN(path.emplace_back(), + CreateConstValueStep(Constant(ConstantKind(true)), -1)); + + ASSERT_OK_AND_ASSIGN(path.emplace_back(), + CreateConstValueStep(Constant(ConstantKind(true)), -1)); + + ASSERT_OK_AND_ASSIGN(path.emplace_back(), + CreateConstValueStep(Constant(ConstantKind(false)), -1)); + + // Just a placeholder. + ASSERT_OK_AND_ASSIGN( + path.emplace_back(), + CreateConstValueStep(Constant(NullValue::kNullValue), -1)); + + PlannerContext context(resolver_, type_registry_, options_, builder_warnings_, + path, tree); + + google::protobuf::Arena arena; + constexpr int kStackLimit = 1; + ConstantFoldingExtension constant_folder(kStackLimit, &arena); + + // Act + // Issue the visitation calls. + ASSERT_OK(constant_folder.OnPreVisit(context, call)); + ASSERT_OK(constant_folder.OnPreVisit(context, condition)); + ASSERT_OK(constant_folder.OnPostVisit(context, condition)); + ASSERT_OK(constant_folder.OnPreVisit(context, true_branch)); + ASSERT_OK(constant_folder.OnPostVisit(context, true_branch)); + ASSERT_OK(constant_folder.OnPreVisit(context, false_branch)); + ASSERT_OK(constant_folder.OnPostVisit(context, false_branch)); + ASSERT_OK(constant_folder.OnPostVisit(context, call)); + + // Assert + // No changes attempted. + EXPECT_THAT(path, SizeIs(4)); +} + +TEST_F(UpdatedConstantFoldingTest, SkipsOr) { + // Arrange + ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, + ParseFromCel("false || true")); + AstImpl& ast_impl = AstImpl::CastFromPublicAst(*ast); + + const Expr& call = ast_impl.root_expr(); + const Expr& left_condition = call.call_expr().args()[0]; + const Expr& right_condition = call.call_expr().args()[1]; + + PlannerContext::ProgramTree tree; + PlannerContext::ProgramInfo& call_info = tree[&call]; + call_info.range_start = 0; + call_info.range_len = 4; + call_info.children = {&left_condition, &right_condition}; + + PlannerContext::ProgramInfo& left_condition_info = tree[&left_condition]; + left_condition_info.range_start = 0; + left_condition_info.range_len = 1; + left_condition_info.parent = &call; + + PlannerContext::ProgramInfo& right_condition_info = tree[&right_condition]; + right_condition_info.range_start = 1; + right_condition_info.range_len = 1; + right_condition_info.parent = &call; + + // Mock execution path that has placeholders for the non-shortcircuiting + // version of ternary. + ExecutionPath path; + + ASSERT_OK_AND_ASSIGN(path.emplace_back(), + CreateConstValueStep(Constant(ConstantKind(false)), -1)); + + ASSERT_OK_AND_ASSIGN(path.emplace_back(), + CreateConstValueStep(Constant(ConstantKind(true)), -1)); + + // Just a placeholder. + ASSERT_OK_AND_ASSIGN( + path.emplace_back(), + CreateConstValueStep(Constant(NullValue::kNullValue), -1)); + + PlannerContext context(resolver_, type_registry_, options_, builder_warnings_, + path, tree); + + google::protobuf::Arena arena; + constexpr int kStackLimit = 1; + ConstantFoldingExtension constant_folder(kStackLimit, &arena); + + // Act + // Issue the visitation calls. + ASSERT_OK(constant_folder.OnPreVisit(context, call)); + ASSERT_OK(constant_folder.OnPreVisit(context, left_condition)); + ASSERT_OK(constant_folder.OnPostVisit(context, left_condition)); + ASSERT_OK(constant_folder.OnPreVisit(context, right_condition)); + ASSERT_OK(constant_folder.OnPostVisit(context, right_condition)); + ASSERT_OK(constant_folder.OnPostVisit(context, call)); + + // Assert + // No changes attempted. + EXPECT_THAT(path, SizeIs(3)); +} + +TEST_F(UpdatedConstantFoldingTest, SkipsAnd) { + // Arrange + ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, + ParseFromCel("true && false")); + AstImpl& ast_impl = AstImpl::CastFromPublicAst(*ast); + + const Expr& call = ast_impl.root_expr(); + const Expr& left_condition = call.call_expr().args()[0]; + const Expr& right_condition = call.call_expr().args()[1]; + + PlannerContext::ProgramTree tree; + PlannerContext::ProgramInfo& call_info = tree[&call]; + call_info.range_start = 0; + call_info.range_len = 4; + call_info.children = {&left_condition, &right_condition}; + + PlannerContext::ProgramInfo& left_condition_info = tree[&left_condition]; + left_condition_info.range_start = 0; + left_condition_info.range_len = 1; + left_condition_info.parent = &call; + + PlannerContext::ProgramInfo& right_condition_info = tree[&right_condition]; + right_condition_info.range_start = 1; + right_condition_info.range_len = 1; + right_condition_info.parent = &call; + + // Mock execution path that has placeholders for the non-shortcircuiting + // version of ternary. + ExecutionPath path; + + ASSERT_OK_AND_ASSIGN(path.emplace_back(), + CreateConstValueStep(Constant(ConstantKind(true)), -1)); + + ASSERT_OK_AND_ASSIGN(path.emplace_back(), + CreateConstValueStep(Constant(ConstantKind(false)), -1)); + + // Just a placeholder. + ASSERT_OK_AND_ASSIGN( + path.emplace_back(), + CreateConstValueStep(Constant(NullValue::kNullValue), -1)); + + PlannerContext context(resolver_, type_registry_, options_, builder_warnings_, + path, tree); + + google::protobuf::Arena arena; + constexpr int kStackLimit = 1; + ConstantFoldingExtension constant_folder(kStackLimit, &arena); + + // Act + // Issue the visitation calls. + ASSERT_OK(constant_folder.OnPreVisit(context, call)); + ASSERT_OK(constant_folder.OnPreVisit(context, left_condition)); + ASSERT_OK(constant_folder.OnPostVisit(context, left_condition)); + ASSERT_OK(constant_folder.OnPreVisit(context, right_condition)); + ASSERT_OK(constant_folder.OnPostVisit(context, right_condition)); + ASSERT_OK(constant_folder.OnPostVisit(context, call)); + + // Assert + // No changes attempted. + EXPECT_THAT(path, SizeIs(3)); +} + } // namespace } // namespace cel::ast::internal diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index 8a75cca68..1b49917e1 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -273,17 +273,19 @@ class FlatExprVisitor : public cel::ast::internal::AstVisitor { google::api::expr::runtime::ExecutionPath* path, const cel::RuntimeOptions& options, const absl::flat_hash_map>& constant_idents, - google::protobuf::Arena* constant_arena, + google::protobuf::Arena* constant_arena, bool updated_constant_folding, bool enable_comprehension_vulnerability_check, google::api::expr::runtime::BuilderWarnings* warnings, bool enable_regex_precompilation, const absl::flat_hash_map* reference_map, - google::protobuf::Arena* arena) + google::protobuf::Arena* arena, PlannerContext::ProgramTree& program_tree, + PlannerContext& extension_context) : resolver_(resolver), - flattened_path_(path), + execution_path_(path), progress_status_(absl::OkStatus()), resolved_select_expr_(nullptr), + parent_expr_(nullptr), options_(options), constant_idents_(constant_idents), constant_arena_(constant_arena), @@ -293,17 +295,59 @@ class FlatExprVisitor : public cel::ast::internal::AstVisitor { enable_regex_precompilation_(enable_regex_precompilation), regex_program_builder_(options_.regex_max_program_size), reference_map_(reference_map), - arena_(arena) {} + arena_(arena), + program_tree_(program_tree), + extension_context_(extension_context) { + if (updated_constant_folding) { + constexpr int kDefaultConstFoldStackLimit = 64; + constant_folding_.emplace(kDefaultConstFoldStackLimit, constant_arena_); + } + } void PreVisitExpr(const cel::ast::internal::Expr* expr, const cel::ast::internal::SourcePosition*) override { ValidateOrError( !absl::holds_alternative(expr->expr_kind()), "Invalid empty expression"); + if (!progress_status_.ok()) { + return; + } + // TODO(issues/5): this will be generalized later. + if (!(constant_folding_.has_value())) { + return; + } + PlannerContext::ProgramInfo& info = program_tree_[expr]; + info.range_start = execution_path_->size(); + info.parent = parent_expr_; + if (parent_expr_ != nullptr) { + program_tree_[parent_expr_].children.push_back(expr); + } + parent_expr_ = expr; + absl::Status status = + constant_folding_->OnPreVisit(extension_context_, *expr); + if (!status.ok()) { + SetProgressStatusError(status); + } } - void PostVisitExpr(const cel::ast::internal::Expr*, - const cel::ast::internal::SourcePosition*) override {} + void PostVisitExpr(const cel::ast::internal::Expr* expr, + const cel::ast::internal::SourcePosition*) override { + if (!progress_status_.ok()) { + return; + } + // TODO(issues/5): this will be generalized later. + if (!constant_folding_.has_value()) { + return; + } + PlannerContext::ProgramInfo& info = program_tree_[expr]; + info.range_len = execution_path_->size() - info.range_start; + parent_expr_ = info.parent; + absl::Status status = + constant_folding_->OnPostVisit(extension_context_, *expr); + if (!status.ok()) { + SetProgressStatusError(status); + } + } void PostVisitConst(const cel::ast::internal::Constant* const_expr, const cel::ast::internal::Expr* expr, @@ -708,7 +752,7 @@ class FlatExprVisitor : public cel::ast::internal::AstVisitor { std::unique_ptr> step) { if (step.ok() && progress_status_.ok()) { - flattened_path_->push_back(*std::move(step)); + execution_path_->push_back(*std::move(step)); } else { SetProgressStatusError(step.status()); } @@ -717,7 +761,7 @@ class FlatExprVisitor : public cel::ast::internal::AstVisitor { void AddStep( std::unique_ptr step) { if (progress_status_.ok()) { - flattened_path_->push_back(std::move(step)); + execution_path_->push_back(std::move(step)); } } @@ -728,7 +772,7 @@ class FlatExprVisitor : public cel::ast::internal::AstVisitor { } // Index of the next step to be inserted. - int GetCurrentIndex() const { return flattened_path_->size(); } + int GetCurrentIndex() const { return execution_path_->size(); } CondVisitor* FindCondVisitor(const cel::ast::internal::Expr* expr) const { if (cond_visitor_stack_.empty()) { @@ -787,7 +831,7 @@ class FlatExprVisitor : public cel::ast::internal::AstVisitor { } const google::api::expr::runtime::Resolver& resolver_; - google::api::expr::runtime::ExecutionPath* flattened_path_; + google::api::expr::runtime::ExecutionPath* execution_path_; absl::Status progress_status_; std::stack< @@ -806,10 +850,16 @@ class FlatExprVisitor : public cel::ast::internal::AstVisitor { // field is used as marker suppressing CelExpression creation for SELECTs. const cel::ast::internal::Expr* resolved_select_expr_; + // Used for assembling a temporary tree mapping program segments + // to source expr nodes. + const cel::ast::internal::Expr* parent_expr_; + const cel::RuntimeOptions& options_; const absl::flat_hash_map>& constant_idents_; google::protobuf::Arena* constant_arena_; + absl::optional + constant_folding_; std::stack comprehension_stack_; @@ -822,6 +872,8 @@ class FlatExprVisitor : public cel::ast::internal::AstVisitor { reference_map_; google::protobuf::Arena* const arena_; + PlannerContext::ProgramTree& program_tree_; + PlannerContext extension_context_; }; void BinaryCondVisitor::PreVisit(const cel::ast::internal::Expr* expr) { @@ -1251,7 +1303,10 @@ FlatExprBuilder::CreateExpressionImpl( options_.enable_qualified_type_identifiers); absl::flat_hash_map> constant_idents; - PlannerContext extension_context(resolver, warnings_builder); + PlannerContext::ProgramTree program_tree; + PlannerContext extension_context(resolver, *GetTypeRegistry(), options_, + warnings_builder, execution_path, + program_tree); auto& ast_impl = AstImpl::CastFromPublicAst(ast); const cel::ast::internal::Expr* effective_expr = &ast_impl.root_expr(); @@ -1266,7 +1321,7 @@ FlatExprBuilder::CreateExpressionImpl( } cel::ast::internal::Expr const_fold_buffer; - if (constant_folding_) { + if (constant_folding_ && !updated_constant_folding_) { cel::ast::internal::FoldConstants( ast_impl.root_expr(), this->GetRegistry()->InternalGetRegistry(), constant_arena_, constant_idents, const_fold_buffer); @@ -1277,8 +1332,9 @@ FlatExprBuilder::CreateExpressionImpl( FlatExprVisitor visitor( resolver, &execution_path, options_, constant_idents, constant_arena_, - enable_comprehension_vulnerability_check_, &warnings_builder, - enable_regex_precompilation_, &ast_impl.reference_map(), arena.get()); + updated_constant_folding_, enable_comprehension_vulnerability_check_, + &warnings_builder, enable_regex_precompilation_, + &ast_impl.reference_map(), arena.get(), program_tree, extension_context); AstTraverse(effective_expr, &ast_impl.source_info(), &visitor); diff --git a/eval/compiler/flat_expr_builder.h b/eval/compiler/flat_expr_builder.h index ece6a25dc..f2c934293 100644 --- a/eval/compiler/flat_expr_builder.h +++ b/eval/compiler/flat_expr_builder.h @@ -44,8 +44,10 @@ class FlatExprBuilder : public CelExpressionBuilder { // Toggle constant folding optimization. By default it is not enabled. // The provided arena is used to hold the generated constants. - void set_constant_folding(bool enabled, google::protobuf::Arena* arena) { + void set_constant_folding(bool enabled, google::protobuf::Arena* arena, + bool updated = false) { constant_folding_ = enabled; + updated_constant_folding_ = updated; constant_arena_ = arena; } @@ -101,6 +103,7 @@ class FlatExprBuilder : public CelExpressionBuilder { bool enable_regex_precompilation_ = false; bool enable_comprehension_vulnerability_check_ = false; bool constant_folding_ = false; + bool updated_constant_folding_ = false; google::protobuf::Arena* constant_arena_ = nullptr; }; diff --git a/eval/compiler/flat_expr_builder_extensions.cc b/eval/compiler/flat_expr_builder_extensions.cc new file mode 100644 index 000000000..1ddde6902 --- /dev/null +++ b/eval/compiler/flat_expr_builder_extensions.cc @@ -0,0 +1,115 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "eval/compiler/flat_expr_builder_extensions.h" + +#include + +#include "absl/algorithm/container.h" +#include "absl/status/status.h" +#include "absl/types/span.h" +#include "base/ast_internal.h" +#include "eval/eval/evaluator_core.h" + +namespace google::api::expr::runtime { + +ExecutionPathView PlannerContext::GetSubplan( + const cel::ast::internal::Expr& node) const { + auto iter = program_tree_.find(&node); + if (iter == program_tree_.end()) { + return {}; + } + + const ProgramInfo& info = iter->second; + + if (info.range_len == -1) { + // Initial planning for this node hasn't finished. + return {}; + } + + return absl::MakeConstSpan(execution_path_) + .subspan(info.range_start, info.range_len); +} + +absl::Status PlannerContext::ReplaceSubplan( + const cel::ast::internal::Expr& node, ExecutionPath path) { + auto iter = program_tree_.find(&node); + if (iter == program_tree_.end()) { + return absl::InternalError("attempted to rewrite unknown program step"); + } + + ProgramInfo& info = iter->second; + + if (info.range_len == -1) { + // Initial planning for this node hasn't finished. + return absl::InternalError( + "attempted to rewrite program step before completion."); + } + + int new_len = path.size(); + int old_len = info.range_len; + int delta = new_len - old_len; + + // If the replacement is differently sized, insert or erase program step + // slots at the replacement point before applying the replacement steps. + if (delta > 0) { + // Insert enough spaces to accommodate the replacement plan. + for (int i = 0; i < delta; ++i) { + execution_path_.insert( + execution_path_.begin() + info.range_start + info.range_len, nullptr); + } + } else if (delta < 0) { + // Erase spaces down to the size of the new sub plan. + execution_path_.erase(execution_path_.begin() + info.range_start, + execution_path_.begin() + info.range_start - delta); + } + + absl::c_move(std::move(path), execution_path_.begin() + info.range_start); + + info.range_len = new_len; + + // Adjust program range for parent and sibling expr nodes if we needed to + // realign them for the replacement. Note: the program structure is only + // maintained for the immediate neighborhood of node being processed by the + // planner, so descendants are not recursively updated. + auto parent_iter = program_tree_.find(info.parent); + if (parent_iter != program_tree_.end() && delta != 0) { + ProgramInfo& parent_info = parent_iter->second; + if (parent_info.range_len != -1) { + parent_info.range_len += delta; + } + + int idx = -1; + for (int i = 0; i < parent_info.children.size(); ++i) { + if (parent_info.children[i] == &node) { + idx = i; + break; + } + } + if (idx > -1) { + for (int j = idx + 1; j < parent_info.children.size(); ++j) { + program_tree_[parent_info.children[j]].range_start += delta; + } + } + } + + // Invalidate any program tree information for dependencies of the rewritten + // node. + for (const cel::ast::internal::Expr* e : info.children) { + program_tree_.erase(e); + } + + return absl::OkStatus(); +} + +} // namespace google::api::expr::runtime diff --git a/eval/compiler/flat_expr_builder_extensions.h b/eval/compiler/flat_expr_builder_extensions.h index c66141fa3..1a1712ed2 100644 --- a/eval/compiler/flat_expr_builder_extensions.h +++ b/eval/compiler/flat_expr_builder_extensions.h @@ -22,26 +22,65 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_COMPILER_FLAT_EXPR_BUILDER_EXTENSIONS_H_ #define THIRD_PARTY_CEL_CPP_EVAL_COMPILER_FLAT_EXPR_BUILDER_EXTENSIONS_H_ +#include + +#include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "base/ast.h" +#include "base/ast_internal.h" #include "eval/compiler/resolver.h" +#include "eval/eval/evaluator_core.h" #include "eval/eval/expression_build_warning.h" +#include "eval/public/cel_type_registry.h" +#include "runtime/runtime_options.h" namespace google::api::expr::runtime { // Class representing FlatExpr internals exposed to extensions. class PlannerContext { public: + struct ProgramInfo { + int range_start; + int range_len = -1; + const cel::ast::internal::Expr* parent = nullptr; + std::vector children; + }; + + using ProgramTree = + absl::flat_hash_map; + explicit PlannerContext(const Resolver& resolver, - BuilderWarnings& builder_warnings) - : resolver_(resolver), builder_warnings_(builder_warnings) {} + const CelTypeRegistry& type_registry, + const cel::RuntimeOptions& options, + BuilderWarnings& builder_warnings, + ExecutionPath& execution_path, + ProgramTree& program_tree) + : resolver_(resolver), + type_registry_(type_registry), + options_(options), + builder_warnings_(builder_warnings), + execution_path_(execution_path), + program_tree_(program_tree) {} + + // Note: this is invalidated after a sibling or parent is updated. + ExecutionPathView GetSubplan(const cel::ast::internal::Expr& node) const; + + // Note: this can only safely be called on the node being visited. + absl::Status ReplaceSubplan(const cel::ast::internal::Expr& node, + ExecutionPath path); const Resolver& resolver() const { return resolver_; } + const CelTypeRegistry& type_registry() const { return type_registry_; } + const cel::RuntimeOptions& options() const { return options_; } BuilderWarnings& builder_warnings() { return builder_warnings_; } private: const Resolver& resolver_; + const CelTypeRegistry& type_registry_; + const cel::RuntimeOptions& options_; BuilderWarnings& builder_warnings_; + ExecutionPath& execution_path_; + ProgramTree& program_tree_; }; // Interface for Ast Transforms. diff --git a/eval/compiler/flat_expr_builder_extensions_test.cc b/eval/compiler/flat_expr_builder_extensions_test.cc new file mode 100644 index 000000000..805f5bdb6 --- /dev/null +++ b/eval/compiler/flat_expr_builder_extensions_test.cc @@ -0,0 +1,266 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "eval/compiler/flat_expr_builder_extensions.h" + +#include + +#include "absl/status/status.h" +#include "base/ast_internal.h" +#include "eval/compiler/resolver.h" +#include "eval/eval/const_value_step.h" +#include "eval/eval/evaluator_core.h" +#include "eval/eval/expression_build_warning.h" +#include "eval/public/cel_type_registry.h" +#include "internal/testing.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" + +namespace google::api::expr::runtime { +namespace { + +using ::cel::ast::internal::Constant; +using ::cel::ast::internal::Expr; +using ::cel::ast::internal::NullValue; +using testing::ElementsAre; +using testing::IsEmpty; +using cel::internal::StatusIs; + +class PlannerContextTest : public testing::Test { + public: + PlannerContextTest() + : type_registry_(), + function_registry_(), + resolver_("", function_registry_, &type_registry_) {} + void SetUp() override {} + + protected: + CelTypeRegistry type_registry_; + cel::FunctionRegistry function_registry_; + cel::RuntimeOptions options_; + Resolver resolver_; + BuilderWarnings builder_warnings_; +}; + +MATCHER_P(UniquePtrHolds, ptr, "") { + const auto& got = arg; + return ptr == got.get(); +} + +// simulate a program of: +// a +// / \ +// b c +absl::StatusOr InitSimpleTree( + const Expr& a, const Expr& b, Expr& c, PlannerContext::ProgramTree& tree) { + Constant null; + null.set_null_value(NullValue::kNullValue); + + CEL_ASSIGN_OR_RETURN(auto a_step, CreateConstValueStep(null, -1)); + CEL_ASSIGN_OR_RETURN(auto b_step, CreateConstValueStep(null, -1)); + CEL_ASSIGN_OR_RETURN(auto c_step, CreateConstValueStep(null, -1)); + + ExecutionPath path; + path.push_back(std::move(b_step)); + path.push_back(std::move(c_step)); + path.push_back(std::move(a_step)); + + PlannerContext::ProgramInfo& a_info = tree[&a]; + a_info.range_start = 0; + a_info.range_len = 3; + a_info.children = {&b, &c}; + + PlannerContext::ProgramInfo& b_info = tree[&b]; + b_info.range_start = 0; + b_info.range_len = 1; + b_info.parent = &a; + + PlannerContext::ProgramInfo& c_info = tree[&c]; + c_info.range_start = 1; + c_info.range_len = 1; + c_info.parent = &a; + + return path; +} + +TEST_F(PlannerContextTest, GetPlan) { + Expr a; + Expr b; + Expr c; + PlannerContext::ProgramTree tree; + + ASSERT_OK_AND_ASSIGN(ExecutionPath path, InitSimpleTree(a, b, c, tree)); + + const ExpressionStep* b_step_ptr = path[0].get(); + const ExpressionStep* c_step_ptr = path[1].get(); + const ExpressionStep* a_step_ptr = path[2].get(); + + PlannerContext context(resolver_, type_registry_, options_, builder_warnings_, + path, tree); + + EXPECT_THAT(context.GetSubplan(a), ElementsAre(UniquePtrHolds(b_step_ptr), + UniquePtrHolds(c_step_ptr), + UniquePtrHolds(a_step_ptr))); + + EXPECT_THAT(context.GetSubplan(b), ElementsAre(UniquePtrHolds(b_step_ptr))); + + EXPECT_THAT(context.GetSubplan(c), ElementsAre(UniquePtrHolds(c_step_ptr))); + + Expr d; + EXPECT_THAT(context.GetSubplan(d), IsEmpty()); +} + +TEST_F(PlannerContextTest, ReplacePlan) { + Expr a; + Expr b; + Expr c; + PlannerContext::ProgramTree tree; + + ASSERT_OK_AND_ASSIGN(ExecutionPath path, InitSimpleTree(a, b, c, tree)); + + const ExpressionStep* b_step_ptr = path[0].get(); + const ExpressionStep* c_step_ptr = path[1].get(); + const ExpressionStep* a_step_ptr = path[2].get(); + + PlannerContext context(resolver_, type_registry_, options_, builder_warnings_, + path, tree); + + EXPECT_THAT(context.GetSubplan(a), ElementsAre(UniquePtrHolds(b_step_ptr), + UniquePtrHolds(c_step_ptr), + UniquePtrHolds(a_step_ptr))); + + ExecutionPath new_a; + Constant null; + null.set_null_value(NullValue::kNullValue); + ASSERT_OK_AND_ASSIGN(auto new_a_step, CreateConstValueStep(null, -1)); + const ExpressionStep* new_a_step_ptr = new_a_step.get(); + new_a.push_back(std::move(new_a_step)); + + ASSERT_OK(context.ReplaceSubplan(a, std::move(new_a))); + + EXPECT_THAT(context.GetSubplan(a), + ElementsAre(UniquePtrHolds(new_a_step_ptr))); + EXPECT_THAT(context.GetSubplan(b), IsEmpty()); +} + +TEST_F(PlannerContextTest, ReplacePlanUpdatesParent) { + Expr a; + Expr b; + Expr c; + PlannerContext::ProgramTree tree; + + ASSERT_OK_AND_ASSIGN(ExecutionPath path, InitSimpleTree(a, b, c, tree)); + + const ExpressionStep* b_step_ptr = path[0].get(); + const ExpressionStep* c_step_ptr = path[1].get(); + const ExpressionStep* a_step_ptr = path[2].get(); + + PlannerContext context(resolver_, type_registry_, options_, builder_warnings_, + path, tree); + + EXPECT_THAT(context.GetSubplan(a), ElementsAre(UniquePtrHolds(b_step_ptr), + UniquePtrHolds(c_step_ptr), + UniquePtrHolds(a_step_ptr))); + + ASSERT_OK(context.ReplaceSubplan(c, {})); + + EXPECT_THAT(context.GetSubplan(a), ElementsAre(UniquePtrHolds(b_step_ptr), + UniquePtrHolds(a_step_ptr))); + EXPECT_THAT(context.GetSubplan(c), IsEmpty()); +} + +TEST_F(PlannerContextTest, ReplacePlanUpdatesSibling) { + Expr a; + Expr b; + Expr c; + PlannerContext::ProgramTree tree; + + ASSERT_OK_AND_ASSIGN(ExecutionPath path, InitSimpleTree(a, b, c, tree)); + + const ExpressionStep* b_step_ptr = path[0].get(); + const ExpressionStep* c_step_ptr = path[1].get(); + const ExpressionStep* a_step_ptr = path[2].get(); + + PlannerContext context(resolver_, type_registry_, options_, builder_warnings_, + path, tree); + + EXPECT_THAT(context.GetSubplan(a), ElementsAre(UniquePtrHolds(b_step_ptr), + UniquePtrHolds(c_step_ptr), + UniquePtrHolds(a_step_ptr))); + + ExecutionPath new_b; + Constant null; + null.set_null_value(NullValue::kNullValue); + ASSERT_OK_AND_ASSIGN(auto b1_step, CreateConstValueStep(null, -1)); + const ExpressionStep* b1_step_ptr = b1_step.get(); + new_b.push_back(std::move(b1_step)); + ASSERT_OK_AND_ASSIGN(auto b2_step, CreateConstValueStep(null, -1)); + const ExpressionStep* b2_step_ptr = b2_step.get(); + new_b.push_back(std::move(b2_step)); + + ASSERT_OK(context.ReplaceSubplan(b, std::move(new_b))); + + EXPECT_THAT(context.GetSubplan(c), ElementsAre(UniquePtrHolds(c_step_ptr))); + EXPECT_THAT(context.GetSubplan(b), ElementsAre(UniquePtrHolds(b1_step_ptr), + UniquePtrHolds(b2_step_ptr))); + EXPECT_THAT( + context.GetSubplan(a), + ElementsAre(UniquePtrHolds(b1_step_ptr), UniquePtrHolds(b2_step_ptr), + UniquePtrHolds(c_step_ptr), UniquePtrHolds(a_step_ptr))); +} + +TEST_F(PlannerContextTest, ReplacePlanFailsOnUpdatedNode) { + Expr a; + Expr b; + Expr c; + PlannerContext::ProgramTree tree; + + ASSERT_OK_AND_ASSIGN(ExecutionPath path, InitSimpleTree(a, b, c, tree)); + + const ExpressionStep* b_step_ptr = path[0].get(); + const ExpressionStep* c_step_ptr = path[1].get(); + const ExpressionStep* a_step_ptr = path[2].get(); + + PlannerContext context(resolver_, type_registry_, options_, builder_warnings_, + path, tree); + + EXPECT_THAT(context.GetSubplan(a), ElementsAre(UniquePtrHolds(b_step_ptr), + UniquePtrHolds(c_step_ptr), + UniquePtrHolds(a_step_ptr))); + + ASSERT_OK(context.ReplaceSubplan(a, {})); + EXPECT_THAT(context.ReplaceSubplan(b, {}), + StatusIs(absl::StatusCode::kInternal)); +} + +TEST_F(PlannerContextTest, ReplacePlanFailsOnUnfinishedNode) { + Expr a; + Expr b; + Expr c; + PlannerContext::ProgramTree tree; + + ASSERT_OK_AND_ASSIGN(ExecutionPath path, InitSimpleTree(a, b, c, tree)); + + tree[&a].range_len = -1; + + PlannerContext context(resolver_, type_registry_, options_, builder_warnings_, + path, tree); + + EXPECT_THAT(context.GetSubplan(a), IsEmpty()); + + EXPECT_THAT(context.ReplaceSubplan(a, {}), + StatusIs(absl::StatusCode::kInternal)); +} + +} // namespace +} // namespace google::api::expr::runtime diff --git a/eval/compiler/flat_expr_builder_test.cc b/eval/compiler/flat_expr_builder_test.cc index 06bce0513..2912fa087 100644 --- a/eval/compiler/flat_expr_builder_test.cc +++ b/eval/compiler/flat_expr_builder_test.cc @@ -16,6 +16,7 @@ #include "eval/compiler/flat_expr_builder.h" +#include #include #include #include @@ -25,18 +26,16 @@ #include "google/api/expr/v1alpha1/checked.pb.h" #include "google/api/expr/v1alpha1/syntax.pb.h" -#include "google/protobuf/duration.pb.h" #include "google/protobuf/field_mask.pb.h" #include "google/protobuf/descriptor.pb.h" -#include "google/protobuf/descriptor.h" -#include "google/protobuf/dynamic_message.h" -#include "google/protobuf/message.h" -#include "google/protobuf/text_format.h" +#include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/strings/str_format.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "base/function.h" +#include "base/function_descriptor.h" #include "eval/compiler/qualified_reference_resolver.h" #include "eval/eval/expression_build_warning.h" #include "eval/public/activation.h" @@ -46,9 +45,11 @@ #include "eval/public/cel_expr_builder_factory.h" #include "eval/public/cel_expression.h" #include "eval/public/cel_function_adapter.h" +#include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" #include "eval/public/cel_value.h" #include "eval/public/containers/container_backed_map_impl.h" +#include "eval/public/portable_cel_function_adapter.h" #include "eval/public/structs/cel_proto_descriptor_pool_builder.h" #include "eval/public/structs/cel_proto_wrapper.h" #include "eval/public/structs/protobuf_descriptor_type_provider.h" @@ -60,17 +61,23 @@ #include "internal/testing.h" #include "parser/parser.h" #include "runtime/runtime_options.h" +#include "google/protobuf/dynamic_message.h" namespace google::api::expr::runtime { namespace { +using ::cel::Handle; +using ::cel::Value; using ::google::api::expr::v1alpha1::CheckedExpr; using ::google::api::expr::v1alpha1::Expr; using ::google::api::expr::v1alpha1::ParsedExpr; using ::google::api::expr::v1alpha1::SourceInfo; +using testing::_; using testing::Eq; using testing::HasSubstr; +using testing::SizeIs; +using testing::Truly; using cel::internal::StatusIs; inline constexpr absl::string_view kSimpleTestMessageDescriptorSetFile = @@ -2075,6 +2082,182 @@ INSTANTIATE_TEST_SUITE_P( }, test::IsCelTimestamp(absl::FromUnixSeconds(20))}})); +struct ConstantFoldingTestCase { + std::string test_name; + std::string expr; + test::CelValueMatcher matcher; + absl::flat_hash_map values; +}; + +class UnknownFunctionImpl : public cel::Function { + absl::StatusOr> Invoke( + const cel::Function::InvokeContext& ctx, + absl::Span> args) const override { + return ctx.value_factory().CreateUnknownValue(); + } +}; + +absl::StatusOr> +CreateConstantFoldingConformanceTestExprBuilder( + const InterpreterOptions& options) { + auto builder = + google::api::expr::runtime::CreateCelExpressionBuilder(options); + CEL_RETURN_IF_ERROR( + RegisterBuiltinFunctions(builder->GetRegistry(), options)); + CEL_RETURN_IF_ERROR(builder->GetRegistry()->RegisterLazyFunction( + cel::FunctionDescriptor("LazyFunction", false, {}))); + CEL_RETURN_IF_ERROR(builder->GetRegistry()->RegisterLazyFunction( + cel::FunctionDescriptor("LazyFunction", false, {cel::Kind::kBool}))); + CEL_RETURN_IF_ERROR(builder->GetRegistry()->Register( + cel::FunctionDescriptor("UnknownFunction", false, {}), + std::make_unique())); + return builder; +} + +class ConstantFoldingConformanceTest + : public ::testing::TestWithParam { + protected: + google::protobuf::Arena arena_; +}; + +TEST_P(ConstantFoldingConformanceTest, Legacy) { + InterpreterOptions options; + options.constant_folding = true; + options.constant_arena = &arena_; + options.enable_updated_constant_folding = false; + // Check interaction between const folding and list append optimizations. + options.enable_comprehension_list_append = true; + + const ConstantFoldingTestCase& p = GetParam(); + + ASSERT_OK_AND_ASSIGN( + auto builder, CreateConstantFoldingConformanceTestExprBuilder(options)); + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(p.expr)); + + ASSERT_OK_AND_ASSIGN( + auto plan, builder->CreateExpression(&expr.expr(), &expr.source_info())); + + Activation activation; + ASSERT_OK(activation.InsertFunction( + PortableUnaryFunctionAdapter::Create( + "LazyFunction", false, + [](google::protobuf::Arena* arena, bool val) { return val; }))); + for (auto iter = p.values.begin(); iter != p.values.end(); ++iter) { + activation.InsertValue(iter->first, CelValue::CreateInt64(iter->second)); + } + + ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(activation, &arena_)); + // Check that none of the memoized constants are being mutated. + ASSERT_OK_AND_ASSIGN(result, plan->Evaluate(activation, &arena_)); + EXPECT_THAT(result, p.matcher); +} + +TEST_P(ConstantFoldingConformanceTest, Updated) { + InterpreterOptions options; + options.constant_folding = true; + options.constant_arena = &arena_; + options.enable_updated_constant_folding = true; + // Check interaction between const folding and list append optimizations. + options.enable_comprehension_list_append = true; + + const ConstantFoldingTestCase& p = GetParam(); + ASSERT_OK_AND_ASSIGN( + auto builder, CreateConstantFoldingConformanceTestExprBuilder(options)); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(p.expr)); + + ASSERT_OK_AND_ASSIGN( + auto plan, builder->CreateExpression(&expr.expr(), &expr.source_info())); + + Activation activation; + ASSERT_OK(activation.InsertFunction( + PortableUnaryFunctionAdapter::Create( + "LazyFunction", false, + [](google::protobuf::Arena* arena, bool val) { return val; }))); + + for (auto iter = p.values.begin(); iter != p.values.end(); ++iter) { + activation.InsertValue(iter->first, CelValue::CreateInt64(iter->second)); + } + ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(activation, &arena_)); + // Check that none of the memoized constants are being mutated. + ASSERT_OK_AND_ASSIGN(result, plan->Evaluate(activation, &arena_)); + EXPECT_THAT(result, p.matcher); +} + +INSTANTIATE_TEST_SUITE_P( + Exprs, ConstantFoldingConformanceTest, + ::testing::ValuesIn(std::vector{ + {"simple_add", "1 + 2 + 3", test::IsCelInt64(6)}, + {"add_with_var", + "1 + (2 + (3 + id))", + test::IsCelInt64(10), + {{"id", 4}}}, + {"const_list", "[1, 2, 3, 4]", test::IsCelList(_)}, + {"mixed_const_list", + "[1, 2, 3, 4] + [id]", + test::IsCelList(_), + {{"id", 5}}}, + {"create_struct", "{'abc': 'def', 'def': 'efg', 'efg': 'hij'}", + Truly([](const CelValue& v) { return v.IsMap(); })}, + {"field_selection", "{'abc': 123}.abc == 123", test::IsCelBool(true)}, + {"type_coverage", + // coverage for constant literals, type() is used to make the list + // homogenous. + R"cel( + [type(bool), + type(123), + type(123u), + type(12.3), + type(b'123'), + type('123'), + type(null), + type(timestamp(0)), + type(duration('1h')) + ])cel", + test::IsCelList(SizeIs(9))}, + {"lazy_function", "true || LazyFunction()", test::IsCelBool(true)}, + {"lazy_function_called", "LazyFunction(true) || false", + test::IsCelBool(true)}, + {"unknown_function", "UnknownFunction() && false", + test::IsCelBool(false)}, + {"nested_comprehension", + "[1, 2, 3, 4].all(x, [5, 6, 7, 8].all(y, x < y))", + test::IsCelBool(true)}, + // Implementation detail: map and filter use replace the accu_init + // expr with a special mutable list to avoid quadratic memory usage + // building the projected list. + {"map", "[1, 2, 3, 4].map(x, x * 2).size() == 4", + test::IsCelBool(true)}, + {"str_cat", + "'1234567890' + '1234567890' + '1234567890' + '1234567890' + " + "'1234567890'", + test::IsCelString( + "12345678901234567890123456789012345678901234567890")}})); + +// Check that list literals are pre-computed +TEST(UpdatedConstantFolding, FoldsLists) { + InterpreterOptions options; + google::protobuf::Arena arena; + options.constant_folding = true; + options.constant_arena = &arena; + options.enable_updated_constant_folding = true; + + ASSERT_OK_AND_ASSIGN( + auto builder, CreateConstantFoldingConformanceTestExprBuilder(options)); + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, + parser::Parse("[1] + [2] + [3] + [4] + [5] + [6] + [7] " + "+ [8] + [9] + [10] + [11] + [12]")); + + ASSERT_OK_AND_ASSIGN( + auto plan, builder->CreateExpression(&expr.expr(), &expr.source_info())); + Activation activation; + int before_size = arena.SpaceUsed(); + ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(activation, &arena)); + // Some incidental allocations are expected related to interop. + EXPECT_LT(arena.SpaceUsed() - before_size, 100); + EXPECT_THAT(result, test::IsCelList(SizeIs(12))); +} + } // namespace } // namespace google::api::expr::runtime diff --git a/eval/public/cel_options.h b/eval/public/cel_options.h index ec685d988..706ec5403 100644 --- a/eval/public/cel_options.h +++ b/eval/public/cel_options.h @@ -52,6 +52,7 @@ struct InterpreterOptions { // Note that expression tracing applies a modified expression if this option // is enabled. bool constant_folding = false; + bool enable_updated_constant_folding = false; google::protobuf::Arena* constant_arena = nullptr; // Enable comprehension expressions (e.g. exists, all) diff --git a/eval/public/portable_cel_expr_builder_factory.cc b/eval/public/portable_cel_expr_builder_factory.cc index 9ea9ca1c7..e8950af42 100644 --- a/eval/public/portable_cel_expr_builder_factory.cc +++ b/eval/public/portable_cel_expr_builder_factory.cc @@ -51,7 +51,8 @@ std::unique_ptr CreatePortableExprBuilder( options.enable_comprehension_vulnerability_check); builder->set_enable_regex_precompilation(options.enable_regex_precompilation); builder->set_constant_folding(options.constant_folding, - options.constant_arena); + options.constant_arena, + options.enable_updated_constant_folding); return builder; } diff --git a/eval/tests/BUILD b/eval/tests/BUILD index 13399fd5e..97c72ec9a 100644 --- a/eval/tests/BUILD +++ b/eval/tests/BUILD @@ -8,9 +8,9 @@ licenses(["notice"]) exports_files(["LICENSE"]) -cc_test( - name = "benchmark_test", - size = "small", +cc_library( + name = "benchmark_testlib", + testonly = True, srcs = [ "benchmark_test.cc", ], @@ -29,17 +29,41 @@ cc_test( "//internal:status_macros", "//internal:testing", "//parser", - "@com_github_google_benchmark//:benchmark", - "@com_github_google_benchmark//:benchmark_main", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:node_hash_set", + "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/strings", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", "@com_google_googleapis//google/rpc/context:attribute_context_cc_proto", "@com_google_protobuf//:protobuf", ], + alwayslink = True, +) + +cc_test( + name = "benchmark_test", + size = "small", + deps = [ + ":benchmark_testlib", + "@com_github_google_benchmark//:benchmark", + "@com_github_google_benchmark//:benchmark_main", + ], +) + +# copybara:strip_begin +# benchy will still need the enable_optimizations flag since it isn't using blaze run directly. +# copybara:strip_end +cc_test( + name = "const_folding_benchmark_test", + size = "small", + args = ["--enable_optimizations"], + deps = [ + ":benchmark_testlib", + "@com_github_google_benchmark//:benchmark", + "@com_github_google_benchmark//:benchmark_main", + ], ) cc_test( diff --git a/eval/tests/benchmark_test.cc b/eval/tests/benchmark_test.cc index 647823618..95b538d60 100644 --- a/eval/tests/benchmark_test.cc +++ b/eval/tests/benchmark_test.cc @@ -11,6 +11,7 @@ #include "absl/container/btree_map.h" #include "absl/container/flat_hash_set.h" #include "absl/container/node_hash_set.h" +#include "absl/flags/flag.h" #include "absl/strings/match.h" #include "eval/public/activation.h" #include "eval/public/builtin_func_registrar.h" @@ -25,6 +26,9 @@ #include "internal/status_macros.h" #include "internal/testing.h" #include "parser/parser.h" +#include "google/protobuf/arena.h" + +ABSL_FLAG(bool, enable_optimizations, false, "enable const folding opt"); namespace google { namespace api { @@ -37,13 +41,27 @@ using ::google::api::expr::v1alpha1::Expr; using ::google::api::expr::v1alpha1::SourceInfo; using ::google::rpc::context::AttributeContext; +InterpreterOptions GetOptions(google::protobuf::Arena& arena) { + InterpreterOptions options; + + if (absl::GetFlag(FLAGS_enable_optimizations)) { + options.enable_updated_constant_folding = true; + options.constant_arena = &arena; + options.constant_folding = true; + } + + return options; +} + // Benchmark test // Evaluates cel expression: // '1 + 1 + 1 .... +1' static void BM_Eval(benchmark::State& state) { - auto builder = CreateCelExpressionBuilder(); - auto reg_status = RegisterBuiltinFunctions(builder->GetRegistry()); - ASSERT_OK(reg_status); + google::protobuf::Arena arena; + InterpreterOptions options = GetOptions(arena); + + auto builder = CreateCelExpressionBuilder(options); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); int len = state.range(0); @@ -84,9 +102,11 @@ absl::Status EmptyCallback(int64_t expr_id, const CelValue& value, // Traces cel expression with an empty callback: // '1 + 1 + 1 .... +1' static void BM_Eval_Trace(benchmark::State& state) { - auto builder = CreateCelExpressionBuilder(); - auto reg_status = RegisterBuiltinFunctions(builder->GetRegistry()); - ASSERT_OK(reg_status); + google::protobuf::Arena arena; + InterpreterOptions options = GetOptions(arena); + + auto builder = CreateCelExpressionBuilder(options); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); int len = state.range(0); @@ -124,9 +144,11 @@ BENCHMARK(BM_Eval_Trace)->Range(1, 10000); // Evaluates cel expression: // '"a" + "a" + "a" .... + "a"' static void BM_EvalString(benchmark::State& state) { - auto builder = CreateCelExpressionBuilder(); - auto reg_status = RegisterBuiltinFunctions(builder->GetRegistry()); - ASSERT_OK(reg_status); + google::protobuf::Arena arena; + InterpreterOptions options = GetOptions(arena); + + auto builder = CreateCelExpressionBuilder(options); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); int len = state.range(0); @@ -164,9 +186,11 @@ BENCHMARK(BM_EvalString)->Range(1, 10000); // Traces cel expression with an empty callback: // '"a" + "a" + "a" .... + "a"' static void BM_EvalString_Trace(benchmark::State& state) { - auto builder = CreateCelExpressionBuilder(); - auto reg_status = RegisterBuiltinFunctions(builder->GetRegistry()); - ASSERT_OK(reg_status); + google::protobuf::Arena arena; + InterpreterOptions options = GetOptions(arena); + + auto builder = CreateCelExpressionBuilder(options); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); int len = state.range(0); @@ -258,12 +282,12 @@ void BM_PolicySymbolic(benchmark::State& state) { ]) ))cel")); - InterpreterOptions options; + InterpreterOptions options = GetOptions(arena); options.constant_folding = true; options.constant_arena = &arena; auto builder = CreateCelExpressionBuilder(options); - ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); SourceInfo source_info; ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression( @@ -316,8 +340,10 @@ void BM_PolicySymbolicMap(benchmark::State& state) { request.ip in ["10.0.1.1", "10.0.1.2", "10.0.1.3"]) ))cel")); - auto builder = CreateCelExpressionBuilder(); - ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + InterpreterOptions options = GetOptions(arena); + + auto builder = CreateCelExpressionBuilder(options); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); SourceInfo source_info; ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression( @@ -347,8 +373,10 @@ void BM_PolicySymbolicProto(benchmark::State& state) { request.ip in ["10.0.1.1", "10.0.1.2", "10.0.1.3"]) ))cel")); - auto builder = CreateCelExpressionBuilder(); - ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + InterpreterOptions options = GetOptions(arena); + + auto builder = CreateCelExpressionBuilder(options); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); SourceInfo source_info; ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression( @@ -435,10 +463,12 @@ void BM_Comprehension(benchmark::State& state) { ContainerBackedListImpl cel_list(std::move(list)); activation.InsertValue("list", CelValue::CreateList(&cel_list)); - InterpreterOptions options; + + InterpreterOptions options = GetOptions(arena); options.comprehension_max_iterations = 10000000; auto builder = CreateCelExpressionBuilder(options); - ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&expr, nullptr)); for (auto _ : state) { @@ -466,10 +496,11 @@ void BM_Comprehension_Trace(benchmark::State& state) { ContainerBackedListImpl cel_list(std::move(list)); activation.InsertValue("list", CelValue::CreateList(&cel_list)); - InterpreterOptions options; + InterpreterOptions options = GetOptions(arena); options.comprehension_max_iterations = 10000000; auto builder = CreateCelExpressionBuilder(options); - ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&expr, nullptr)); for (auto _ : state) { @@ -487,8 +518,11 @@ void BM_HasMap(benchmark::State& state) { Activation activation; ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse("has(request.path) && !has(request.ip)")); - auto builder = CreateCelExpressionBuilder(); - ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + + InterpreterOptions options = GetOptions(arena); + auto builder = CreateCelExpressionBuilder(options); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&parsed_expr.expr(), nullptr)); @@ -514,8 +548,10 @@ void BM_HasProto(benchmark::State& state) { Activation activation; ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse("has(request.path) && !has(request.ip)")); - auto builder = CreateCelExpressionBuilder(); - ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + InterpreterOptions options = GetOptions(arena); + auto builder = CreateCelExpressionBuilder(options); + auto reg_status = RegisterBuiltinFunctions(builder->GetRegistry(), options); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&parsed_expr.expr(), nullptr)); @@ -541,8 +577,10 @@ void BM_HasProtoMap(benchmark::State& state) { ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse("has(request.headers.create_time) && " "!has(request.headers.update_time)")); - auto builder = CreateCelExpressionBuilder(); - ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + InterpreterOptions options = GetOptions(arena); + auto builder = CreateCelExpressionBuilder(options); + auto reg_status = RegisterBuiltinFunctions(builder->GetRegistry(), options); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&parsed_expr.expr(), nullptr)); @@ -567,8 +605,10 @@ void BM_ReadProtoMap(benchmark::State& state) { ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse(R"cel( request.headers.create_time == "2021-01-01" )cel")); - auto builder = CreateCelExpressionBuilder(); - ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + InterpreterOptions options = GetOptions(arena); + auto builder = CreateCelExpressionBuilder(options); + auto reg_status = RegisterBuiltinFunctions(builder->GetRegistry(), options); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&parsed_expr.expr(), nullptr)); @@ -593,8 +633,10 @@ void BM_NestedProtoFieldRead(benchmark::State& state) { ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse(R"cel( !request.a.b.c.d.e )cel")); - auto builder = CreateCelExpressionBuilder(); - ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + InterpreterOptions options = GetOptions(arena); + auto builder = CreateCelExpressionBuilder(options); + auto reg_status = RegisterBuiltinFunctions(builder->GetRegistry(), options); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&parsed_expr.expr(), nullptr)); @@ -619,8 +661,10 @@ void BM_NestedProtoFieldReadDefaults(benchmark::State& state) { ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse(R"cel( !request.a.b.c.d.e )cel")); - auto builder = CreateCelExpressionBuilder(); - ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + InterpreterOptions options = GetOptions(arena); + auto builder = CreateCelExpressionBuilder(options); + auto reg_status = RegisterBuiltinFunctions(builder->GetRegistry(), options); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&parsed_expr.expr(), nullptr)); @@ -644,8 +688,10 @@ void BM_ProtoStructAccess(benchmark::State& state) { ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse(R"cel( has(request.auth.claims.iss) && request.auth.claims.iss == 'accounts.google.com' )cel")); - auto builder = CreateCelExpressionBuilder(); - ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + InterpreterOptions options = GetOptions(arena); + auto builder = CreateCelExpressionBuilder(options); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&parsed_expr.expr(), nullptr)); @@ -672,8 +718,10 @@ void BM_ProtoListAccess(benchmark::State& state) { ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse(R"cel( "//.../accessLevels/MY_LEVEL_4" in request.auth.access_levels )cel")); - auto builder = CreateCelExpressionBuilder(); - ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + InterpreterOptions options = GetOptions(arena); + auto builder = CreateCelExpressionBuilder(options); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&parsed_expr.expr(), nullptr)); @@ -806,10 +854,11 @@ void BM_NestedComprehension(benchmark::State& state) { ContainerBackedListImpl cel_list(std::move(list)); activation.InsertValue("list", CelValue::CreateList(&cel_list)); - InterpreterOptions options; + InterpreterOptions options = GetOptions(arena); options.comprehension_max_iterations = 10000000; auto builder = CreateCelExpressionBuilder(options); - ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&expr, nullptr)); @@ -838,11 +887,11 @@ void BM_NestedComprehension_Trace(benchmark::State& state) { ContainerBackedListImpl cel_list(std::move(list)); activation.InsertValue("list", CelValue::CreateList(&cel_list)); - InterpreterOptions options; + InterpreterOptions options = GetOptions(arena); options.comprehension_max_iterations = 10000000; options.enable_comprehension_list_append = true; auto builder = CreateCelExpressionBuilder(options); - ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&expr, nullptr)); @@ -871,11 +920,11 @@ void BM_ListComprehension(benchmark::State& state) { ContainerBackedListImpl cel_list(std::move(list)); activation.InsertValue("list", CelValue::CreateList(&cel_list)); - InterpreterOptions options; + InterpreterOptions options = GetOptions(arena); options.comprehension_max_iterations = 10000000; options.enable_comprehension_list_append = true; auto builder = CreateCelExpressionBuilder(options); - ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); ASSERT_OK_AND_ASSIGN( auto cel_expr, builder->CreateExpression(&(parsed_expr.expr()), nullptr)); @@ -904,11 +953,11 @@ void BM_ListComprehension_Trace(benchmark::State& state) { ContainerBackedListImpl cel_list(std::move(list)); activation.InsertValue("list", CelValue::CreateList(&cel_list)); - InterpreterOptions options; + InterpreterOptions options = GetOptions(arena); options.comprehension_max_iterations = 10000000; options.enable_comprehension_list_append = true; auto builder = CreateCelExpressionBuilder(options); - ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); ASSERT_OK_AND_ASSIGN( auto cel_expr, builder->CreateExpression(&(parsed_expr.expr()), nullptr)); diff --git a/eval/tests/expression_builder_benchmark_test.cc b/eval/tests/expression_builder_benchmark_test.cc index 42ed8fee0..64f9ad693 100644 --- a/eval/tests/expression_builder_benchmark_test.cc +++ b/eval/tests/expression_builder_benchmark_test.cc @@ -38,7 +38,11 @@ namespace { using google::api::expr::v1alpha1::ParsedExpr; -enum BenchmarkParam : int { kDefault = 0, kFoldConstants = 1 }; +enum BenchmarkParam : int { + kDefault = 0, + kFoldConstants = 1, + kUpdatedFoldConstants = 2 +}; void BM_RegisterBuiltins(benchmark::State& state) { for (auto _ : state) { @@ -50,6 +54,26 @@ void BM_RegisterBuiltins(benchmark::State& state) { BENCHMARK(BM_RegisterBuiltins); +InterpreterOptions OptionsForParam(BenchmarkParam param, google::protobuf::Arena& arena) { + InterpreterOptions options; + + switch (param) { + case BenchmarkParam::kFoldConstants: + options.constant_arena = &arena; + options.constant_folding = true; + break; + case BenchmarkParam::kUpdatedFoldConstants: + options.constant_arena = &arena; + options.constant_folding = true; + options.enable_updated_constant_folding = true; + break; + case BenchmarkParam::kDefault: + options.constant_folding = false; + break; + } + return options; +} + void BM_SymbolicPolicy(benchmark::State& state) { auto param = static_cast(state.range(0)); @@ -62,17 +86,7 @@ void BM_SymbolicPolicy(benchmark::State& state) { ))cel")); google::protobuf::Arena arena; - InterpreterOptions options; - - switch (param) { - case BenchmarkParam::kFoldConstants: - options.constant_arena = &arena; - options.constant_folding = true; - break; - case BenchmarkParam::kDefault: - options.constant_folding = false; - break; - } + InterpreterOptions options = OptionsForParam(param, arena); auto builder = CreateCelExpressionBuilder(options); auto reg_status = RegisterBuiltinFunctions(builder->GetRegistry()); @@ -82,12 +96,14 @@ void BM_SymbolicPolicy(benchmark::State& state) { ASSERT_OK_AND_ASSIGN( auto expression, builder->CreateExpression(&expr.expr(), &expr.source_info())); + arena.Reset(); } } BENCHMARK(BM_SymbolicPolicy) ->Arg(BenchmarkParam::kDefault) - ->Arg(BenchmarkParam::kFoldConstants); + ->Arg(BenchmarkParam::kFoldConstants) + ->Arg(BenchmarkParam::kUpdatedFoldConstants); void BM_NestedComprehension(benchmark::State& state) { auto param = static_cast(state.range(0)); @@ -96,18 +112,8 @@ void BM_NestedComprehension(benchmark::State& state) { [4, 5, 6].all(x, [1, 2, 3].all(y, x > y) && [7, 8, 9].all(z, x < z)) )")); - InterpreterOptions options; google::protobuf::Arena arena; - - switch (param) { - case BenchmarkParam::kFoldConstants: - options.constant_arena = &arena; - options.constant_folding = true; - break; - case BenchmarkParam::kDefault: - options.constant_folding = false; - break; - } + InterpreterOptions options = OptionsForParam(param, arena); auto builder = CreateCelExpressionBuilder(options); auto reg_status = RegisterBuiltinFunctions(builder->GetRegistry()); @@ -117,12 +123,14 @@ void BM_NestedComprehension(benchmark::State& state) { ASSERT_OK_AND_ASSIGN( auto expression, builder->CreateExpression(&expr.expr(), &expr.source_info())); + arena.Reset(); } } BENCHMARK(BM_NestedComprehension) ->Arg(BenchmarkParam::kDefault) - ->Arg(BenchmarkParam::kFoldConstants); + ->Arg(BenchmarkParam::kFoldConstants) + ->Arg(BenchmarkParam::kUpdatedFoldConstants); void BM_Comparisons(benchmark::State& state) { auto param = static_cast(state.range(0)); @@ -134,18 +142,8 @@ void BM_Comparisons(benchmark::State& state) { && v11 != v12 && v12 != v13 )")); - InterpreterOptions options; google::protobuf::Arena arena; - - switch (param) { - case BenchmarkParam::kFoldConstants: - options.constant_arena = &arena; - options.constant_folding = true; - break; - case BenchmarkParam::kDefault: - options.constant_folding = false; - break; - } + InterpreterOptions options = OptionsForParam(param, arena); auto builder = CreateCelExpressionBuilder(options); auto reg_status = RegisterBuiltinFunctions(builder->GetRegistry()); @@ -155,12 +153,14 @@ void BM_Comparisons(benchmark::State& state) { ASSERT_OK_AND_ASSIGN( auto expression, builder->CreateExpression(&expr.expr(), &expr.source_info())); + arena.Reset(); } } BENCHMARK(BM_Comparisons) ->Arg(BenchmarkParam::kDefault) - ->Arg(BenchmarkParam::kFoldConstants); + ->Arg(BenchmarkParam::kFoldConstants) + ->Arg(BenchmarkParam::kUpdatedFoldConstants); void BM_StringConcat(benchmark::State& state) { auto param = static_cast(state.range(0)); @@ -177,18 +177,8 @@ void BM_StringConcat(benchmark::State& state) { ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(source)); - InterpreterOptions options; google::protobuf::Arena arena; - - switch (param) { - case BenchmarkParam::kFoldConstants: - options.constant_arena = &arena; - options.constant_folding = true; - break; - case BenchmarkParam::kDefault: - options.constant_folding = false; - break; - } + InterpreterOptions options = OptionsForParam(param, arena); auto builder = CreateCelExpressionBuilder(options); auto reg_status = RegisterBuiltinFunctions(builder->GetRegistry()); @@ -198,6 +188,7 @@ void BM_StringConcat(benchmark::State& state) { ASSERT_OK_AND_ASSIGN( auto expression, builder->CreateExpression(&expr.expr(), &expr.source_info())); + arena.Reset(); } } @@ -211,7 +202,12 @@ BENCHMARK(BM_StringConcat) ->Args({BenchmarkParam::kFoldConstants, 4}) ->Args({BenchmarkParam::kFoldConstants, 8}) ->Args({BenchmarkParam::kFoldConstants, 16}) - ->Args({BenchmarkParam::kFoldConstants, 32}); + ->Args({BenchmarkParam::kFoldConstants, 32}) + ->Args({BenchmarkParam::kUpdatedFoldConstants, 2}) + ->Args({BenchmarkParam::kUpdatedFoldConstants, 4}) + ->Args({BenchmarkParam::kUpdatedFoldConstants, 8}) + ->Args({BenchmarkParam::kUpdatedFoldConstants, 16}) + ->Args({BenchmarkParam::kUpdatedFoldConstants, 32}); } // namespace } // namespace google::api::expr::runtime From 6b099c401fdfa5409c8482a3b5aa4b5981a12a98 Mon Sep 17 00:00:00 2001 From: jcking Date: Wed, 3 May 2023 13:25:24 +0000 Subject: [PATCH 239/303] Update `InlineValue` to correctly account for borrowed `StringValue` and `BytesValue` PiperOrigin-RevId: 529071697 --- base/internal/value.h | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/base/internal/value.h b/base/internal/value.h index 19ee4c317..7a8621eed 100644 --- a/base/internal/value.h +++ b/base/internal/value.h @@ -85,7 +85,10 @@ struct InlineValue final { absl::Time time_value; absl::Status status_value; absl::Cord cord_value; - absl::string_view string_value; + struct { + absl::string_view string_value; + uintptr_t owner; + } string_value; Handle type_value; struct { Handle type; From 5c523fff5cd852647e6d0e0bc9fa17fbac53e5f0 Mon Sep 17 00:00:00 2001 From: jcking Date: Wed, 3 May 2023 22:36:51 +0000 Subject: [PATCH 240/303] Implement support for JSON-like values PiperOrigin-RevId: 529210191 --- base/internal/data.h | 10 +- base/owner.h | 6 +- base/value_factory.h | 11 +- extensions/protobuf/BUILD | 2 + extensions/protobuf/internal/wrappers.cc | 52 +- extensions/protobuf/internal/wrappers.h | 10 + extensions/protobuf/struct_value.cc | 275 +++- extensions/protobuf/struct_value.h | 91 +- extensions/protobuf/struct_value_test.cc | 227 +++- extensions/protobuf/type.cc | 21 +- extensions/protobuf/value.cc | 1507 ++++++++++++++++++++-- extensions/protobuf/value.h | 65 +- extensions/protobuf/value_test.cc | 367 ++++++ 13 files changed, 2508 insertions(+), 136 deletions(-) diff --git a/base/internal/data.h b/base/internal/data.h index 270c33975..5dba01e49 100644 --- a/base/internal/data.h +++ b/base/internal/data.h @@ -352,16 +352,14 @@ template struct SelectMetadataImpl; template -struct SelectMetadataImpl< - T, std::enable_if_t, - std::is_base_of>>> { +struct SelectMetadataImpl>> { using type = TypeMetadata; }; template -struct SelectMetadataImpl< - T, std::enable_if_t, - std::is_base_of>>> { +struct SelectMetadataImpl>> { using type = ValueMetadata; }; diff --git a/base/owner.h b/base/owner.h index 7c41f14e6..0abe125f7 100644 --- a/base/owner.h +++ b/base/owner.h @@ -38,6 +38,8 @@ class Owner { using metadata_type = base_internal::SelectMetadata; public: + static_assert(!std::is_base_of_v); + Owner() = delete; Owner(const Owner& other) noexcept : owner_(other.owner_) { @@ -132,7 +134,9 @@ class Owner { friend class TypeFactory; friend class ValueFactory; - explicit Owner(const T* owner) : owner_(owner) {} + explicit Owner(const T* owner) : owner_(owner) { + static_assert(std::is_base_of_v); + } const T* release() { const T* owner = owner_; diff --git a/base/value_factory.h b/base/value_factory.h index 2f6c0d9f0..63fdee57e 100644 --- a/base/value_factory.h +++ b/base/value_factory.h @@ -81,17 +81,12 @@ class ValueFactory final { std::enable_if_t>, V>; template - using EnableIfReferent = std::enable_if_t< - std::conjunction_v, - std::is_base_of>, - V>; + using EnableIfReferent = std::enable_if_t, V>; template using EnableIfBaseOfAndReferent = std::enable_if_t< - std::conjunction_v< - std::is_base_of>, - std::is_base_of>, - std::is_base_of>>, + std::conjunction_v>, + std::is_base_of>>, V>; public: diff --git a/extensions/protobuf/BUILD b/extensions/protobuf/BUILD index 11793b726..682581945 100644 --- a/extensions/protobuf/BUILD +++ b/extensions/protobuf/BUILD @@ -148,6 +148,7 @@ cc_library( "//base:allocator", "//base:handle", "//base:kind", + "//base:owner", "//base:type", "//base:value", "//eval/internal:errors", @@ -160,6 +161,7 @@ cc_library( "//internal:casts", "//internal:rtti", "//internal:status_macros", + "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/log:die_if_null", diff --git a/extensions/protobuf/internal/wrappers.cc b/extensions/protobuf/internal/wrappers.cc index 3dd6d1a79..0315f0b7f 100644 --- a/extensions/protobuf/internal/wrappers.cc +++ b/extensions/protobuf/internal/wrappers.cc @@ -71,6 +71,12 @@ absl::StatusOr UnwrapBytesValueProto( &google::protobuf::Reflection::GetCord); } +absl::StatusOr UnwrapFloatValueProto(const google::protobuf::Message& message) { + return UnwrapValueProto( + message, google::protobuf::FieldDescriptor::CPPTYPE_FLOAT, + &google::protobuf::Reflection::GetFloat); +} + absl::StatusOr UnwrapDoubleValueProto(const google::protobuf::Message& message) { const auto* desc = message.GetDescriptor(); if (ABSL_PREDICT_FALSE(desc == nullptr)) { @@ -78,9 +84,7 @@ absl::StatusOr UnwrapDoubleValueProto(const google::protobuf::Message& m absl::StrCat(message.GetTypeName(), " missing descriptor")); } if (desc->full_name() == "google.protobuf.FloatValue") { - return UnwrapValueProto( - message, google::protobuf::FieldDescriptor::CPPTYPE_FLOAT, - &google::protobuf::Reflection::GetFloat); + return UnwrapFloatValueProto(message); } if (desc->full_name() == "google.protobuf.DoubleValue") { return UnwrapValueProto( @@ -98,19 +102,27 @@ absl::StatusOr UnwrapIntValueProto(const google::protobuf::Message& mes absl::StrCat(message.GetTypeName(), " missing descriptor")); } if (desc->full_name() == "google.protobuf.Int32Value") { - return UnwrapValueProto( - message, google::protobuf::FieldDescriptor::CPPTYPE_INT32, - &google::protobuf::Reflection::GetInt32); + return UnwrapInt32ValueProto(message); } if (desc->full_name() == "google.protobuf.Int64Value") { - return UnwrapValueProto( - message, google::protobuf::FieldDescriptor::CPPTYPE_INT64, - &google::protobuf::Reflection::GetInt64); + return UnwrapInt64ValueProto(message); } return absl::InvalidArgumentError( absl::StrCat(message.GetTypeName(), " is not int-like")); } +absl::StatusOr UnwrapInt32ValueProto(const google::protobuf::Message& message) { + return UnwrapValueProto( + message, google::protobuf::FieldDescriptor::CPPTYPE_INT32, + &google::protobuf::Reflection::GetInt32); +} + +absl::StatusOr UnwrapInt64ValueProto(const google::protobuf::Message& message) { + return UnwrapValueProto( + message, google::protobuf::FieldDescriptor::CPPTYPE_INT64, + &google::protobuf::Reflection::GetInt64); +} + absl::StatusOr> UnwrapStringValueProto(const google::protobuf::Message& message) { const auto* desc = message.GetDescriptor(); @@ -155,17 +167,27 @@ absl::StatusOr UnwrapUIntValueProto(const google::protobuf::Message& m absl::StrCat(message.GetTypeName(), " missing descriptor")); } if (desc->full_name() == "google.protobuf.UInt32Value") { - return UnwrapValueProto( - message, google::protobuf::FieldDescriptor::CPPTYPE_UINT32, - &google::protobuf::Reflection::GetUInt32); + return UnwrapUInt32ValueProto(message); } if (desc->full_name() == "google.protobuf.UInt64Value") { - return UnwrapValueProto( - message, google::protobuf::FieldDescriptor::CPPTYPE_UINT64, - &google::protobuf::Reflection::GetUInt64); + return UnwrapUInt64ValueProto(message); } return absl::InvalidArgumentError( absl::StrCat(message.GetTypeName(), " is not uint-like")); } +absl::StatusOr UnwrapUInt32ValueProto( + const google::protobuf::Message& message) { + return UnwrapValueProto( + message, google::protobuf::FieldDescriptor::CPPTYPE_UINT32, + &google::protobuf::Reflection::GetUInt32); +} + +absl::StatusOr UnwrapUInt64ValueProto( + const google::protobuf::Message& message) { + return UnwrapValueProto( + message, google::protobuf::FieldDescriptor::CPPTYPE_UINT64, + &google::protobuf::Reflection::GetUInt64); +} + } // namespace cel::extensions::protobuf_internal diff --git a/extensions/protobuf/internal/wrappers.h b/extensions/protobuf/internal/wrappers.h index 125d0f68f..c0c5db15a 100644 --- a/extensions/protobuf/internal/wrappers.h +++ b/extensions/protobuf/internal/wrappers.h @@ -28,15 +28,25 @@ absl::StatusOr UnwrapBoolValueProto(const google::protobuf::Message& messa absl::StatusOr UnwrapBytesValueProto( const google::protobuf::Message& message); +absl::StatusOr UnwrapFloatValueProto(const google::protobuf::Message& message); + absl::StatusOr UnwrapDoubleValueProto(const google::protobuf::Message& message); absl::StatusOr UnwrapIntValueProto(const google::protobuf::Message& message); +absl::StatusOr UnwrapInt32ValueProto(const google::protobuf::Message& message); + +absl::StatusOr UnwrapInt64ValueProto(const google::protobuf::Message& message); + absl::StatusOr> UnwrapStringValueProto(const google::protobuf::Message& message); absl::StatusOr UnwrapUIntValueProto(const google::protobuf::Message& message); +absl::StatusOr UnwrapUInt32ValueProto(const google::protobuf::Message& message); + +absl::StatusOr UnwrapUInt64ValueProto(const google::protobuf::Message& message); + } // namespace cel::extensions::protobuf_internal #endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_INTERNAL_WRAPPERS_H_ diff --git a/extensions/protobuf/struct_value.cc b/extensions/protobuf/struct_value.cc index 76224c34e..04466a60b 100644 --- a/extensions/protobuf/struct_value.cc +++ b/extensions/protobuf/struct_value.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +// TODO(issues/5): get test coverage closer to 100% before using + #include "extensions/protobuf/struct_value.h" #include @@ -93,16 +95,6 @@ namespace protobuf_internal { namespace { -Handle CreateStringValueFromView(absl::string_view value) { - return base_internal::HandleFactory::Make< - base_internal::InlinedStringViewStringValue>(value); -} - -Handle CreateBytesValueFromView(absl::string_view value) { - return base_internal::HandleFactory::Make< - base_internal::InlinedStringViewBytesValue>(value); -} - struct DebugStringFromStringWrapperVisitor final { std::string operator()(absl::string_view value) const { return StringValue::DebugString(value); @@ -113,8 +105,7 @@ struct DebugStringFromStringWrapperVisitor final { } }; -class HeapDynamicParsedProtoStructValue final - : public DynamicParsedProtoStructValue { +class HeapDynamicParsedProtoStructValue : public DynamicParsedProtoStructValue { public: HeapDynamicParsedProtoStructValue(Handle type, const google::protobuf::Message* value) @@ -792,6 +783,156 @@ class ParsedProtoListValue const google::protobuf::RepeatedFieldRef fields_; }; +// repeated google.protobuf.ListValue +template <> +class ParsedProtoListValue + : public CEL_LIST_VALUE_CLASS { + public: + ParsedProtoListValue(Handle type, + google::protobuf::RepeatedFieldRef fields) + : CEL_LIST_VALUE_CLASS(std::move(type)), fields_(std::move(fields)) {} + + std::string DebugString() const final { + std::string out; + out.push_back('['); + auto field = fields_.begin(); + if (field != fields_.end()) { + ProtoDebugStringStruct(out, *field); + ++field; + for (; field != fields_.end(); ++field) { + out.append(", "); + ProtoDebugStringStruct(out, *field); + } + } + out.push_back(']'); + return out; + } + + size_t size() const final { return fields_.size(); } + + bool empty() const final { return fields_.empty(); } + + absl::StatusOr> Get(const GetContext& context, + size_t index) const final { + std::unique_ptr scratch(fields_.NewMessage()); + const auto& field = fields_.Get(static_cast(index), scratch.get()); + if (scratch.get() == &field) { + return protobuf_internal::CreateListValue(context.value_factory(), + std::move(scratch)); + } + scratch.reset(); + return protobuf_internal::CreateBorrowedListValue( + owner_from_this(), context.value_factory(), field); + } + + private: + cel::internal::TypeInfo TypeId() const final { + return internal::TypeId>(); + } + + const google::protobuf::RepeatedFieldRef fields_; +}; + +// repeated google.protobuf.Struct +template <> +class ParsedProtoListValue + : public CEL_LIST_VALUE_CLASS { + public: + ParsedProtoListValue(Handle type, + google::protobuf::RepeatedFieldRef fields) + : CEL_LIST_VALUE_CLASS(std::move(type)), fields_(std::move(fields)) {} + + std::string DebugString() const final { + std::string out; + out.push_back('['); + auto field = fields_.begin(); + if (field != fields_.end()) { + ProtoDebugStringStruct(out, *field); + ++field; + for (; field != fields_.end(); ++field) { + out.append(", "); + ProtoDebugStringStruct(out, *field); + } + } + out.push_back(']'); + return out; + } + + size_t size() const final { return fields_.size(); } + + bool empty() const final { return fields_.empty(); } + + absl::StatusOr> Get(const GetContext& context, + size_t index) const final { + std::unique_ptr scratch(fields_.NewMessage()); + const auto& field = fields_.Get(static_cast(index), scratch.get()); + if (scratch.get() == &field) { + return protobuf_internal::CreateStruct(context.value_factory(), + std::move(scratch)); + } + scratch.reset(); + return protobuf_internal::CreateBorrowedStruct( + owner_from_this(), context.value_factory(), field); + } + + private: + cel::internal::TypeInfo TypeId() const final { + return internal::TypeId>(); + } + + const google::protobuf::RepeatedFieldRef fields_; +}; + +// repeated google.protobuf.Value +template <> +class ParsedProtoListValue + : public CEL_LIST_VALUE_CLASS { + public: + ParsedProtoListValue(Handle type, + google::protobuf::RepeatedFieldRef fields) + : CEL_LIST_VALUE_CLASS(std::move(type)), fields_(std::move(fields)) {} + + std::string DebugString() const final { + std::string out; + out.push_back('['); + auto field = fields_.begin(); + if (field != fields_.end()) { + ProtoDebugStringStruct(out, *field); + ++field; + for (; field != fields_.end(); ++field) { + out.append(", "); + ProtoDebugStringStruct(out, *field); + } + } + out.push_back(']'); + return out; + } + + size_t size() const final { return fields_.size(); } + + bool empty() const final { return fields_.empty(); } + + absl::StatusOr> Get(const GetContext& context, + size_t index) const final { + std::unique_ptr scratch(fields_.NewMessage()); + const auto& field = fields_.Get(static_cast(index), scratch.get()); + if (scratch.get() == &field) { + return protobuf_internal::CreateValue(context.value_factory(), + std::move(scratch)); + } + scratch.reset(); + return protobuf_internal::CreateBorrowedValue( + owner_from_this(), context.value_factory(), field); + } + + private: + cel::internal::TypeInfo TypeId() const final { + return internal::TypeId>(); + } + + const google::protobuf::RepeatedFieldRef fields_; +}; + // repeated google.protobuf.BoolValue template <> class ParsedProtoListValue @@ -1421,6 +1562,21 @@ class ParsedProtoMapValue : public CEL_MAP_VALUE_CLASS { proto_value.GetMessageValue())); return context.value_factory().CreateUncheckedTimestampValue(time); } + case Kind::kList: + // google.protobuf.ListValue + return protobuf_internal::CreateBorrowedListValue( + owner_from_this(), context.value_factory(), + proto_value.GetMessageValue()); + case Kind::kMap: + // google.protobuf.Struct + return protobuf_internal::CreateBorrowedStruct( + owner_from_this(), context.value_factory(), + proto_value.GetMessageValue()); + case Kind::kDyn: + // google.protobuf.Value + return protobuf_internal::CreateBorrowedValue( + owner_from_this(), context.value_factory(), + proto_value.GetMessageValue()); case Kind::kBool: { // google.protobuf.BoolValue, mapped to CEL primitive bool type for // map values. @@ -1849,6 +2005,62 @@ absl::StatusOr> ProtoStructValue::Create( return std::move(status_or_message).value(); } +absl::StatusOr> ProtoStructValue::CreateBorrowed( + Owner owner, ValueFactory& value_factory, + const google::protobuf::Message& message) { + const auto* descriptor = message.GetDescriptor(); + if (ABSL_PREDICT_FALSE(descriptor == nullptr)) { + return absl::InvalidArgumentError("message missing descriptor"); + } + CEL_ASSIGN_OR_RETURN( + auto type, + ProtoStructType::Resolve(value_factory.type_manager(), *descriptor)); + bool same_descriptors = &type->descriptor() == descriptor; + if (ABSL_PREDICT_TRUE(same_descriptors)) { + return value_factory.CreateBorrowedStructValue< + protobuf_internal::DynamicMemberParsedProtoStructValue>( + std::move(owner), std::move(type), &message); + } + const auto* prototype = type->factory_->GetPrototype(&type->descriptor()); + if (ABSL_PREDICT_FALSE(prototype == nullptr)) { + return absl::InternalError(absl::StrCat( + "cel: unable to get prototype for protocol buffer message \"", + type->name(), "\"")); + } + std::string serialized; + if (ABSL_PREDICT_FALSE(!message.SerializePartialToString(&serialized))) { + return absl::InternalError( + "cel: failed to serialize protocol buffer message"); + } + if (ProtoMemoryManager::Is(value_factory.memory_manager())) { + auto* arena = + ProtoMemoryManager::CastToProtoArena(value_factory.memory_manager()); + if (arena != nullptr) { + auto* value = prototype->New(arena); + if (ABSL_PREDICT_FALSE(!value->ParsePartialFromString(serialized))) { + return absl::InternalError( + "cel: failed to deserialize protocol buffer message"); + } + return value_factory.CreateBorrowedStructValue< + protobuf_internal::ArenaDynamicParsedProtoStructValue>( + std::move(owner), std::move(type), value); + } + } + auto value = absl::WrapUnique(prototype->New()); + if (ABSL_PREDICT_FALSE(!value->ParsePartialFromString(serialized))) { + return absl::InternalError( + "cel: failed to deserialize protocol buffer message"); + } + auto status_or_message = value_factory.CreateBorrowedStructValue< + protobuf_internal::HeapDynamicParsedProtoStructValue>( + std::move(owner), std::move(type), value.get()); + if (ABSL_PREDICT_FALSE(!status_or_message.ok())) { + return status_or_message.status(); + } + value.release(); + return std::move(status_or_message).value(); +} + absl::StatusOr> ProtoStructValue::Create( ValueFactory& value_factory, google::protobuf::Message&& message) { const auto* descriptor = message.GetDescriptor(); @@ -2108,6 +2320,30 @@ absl::StatusOr> ParsedProtoStructValue::GetRepeatedField( owner_from_this(), field.type.As(), reflect.GetRepeatedFieldRef(value(), &field_desc)); + case Kind::kList: + // google.protobuf.ListValue + return context.value_factory() + .CreateBorrowedListValue< + ParsedProtoListValue>( + owner_from_this(), field.type.As(), + reflect.GetRepeatedFieldRef(value(), + &field_desc)); + case Kind::kMap: + // google.protobuf.Struct + return context.value_factory() + .CreateBorrowedListValue< + ParsedProtoListValue>( + owner_from_this(), field.type.As(), + reflect.GetRepeatedFieldRef(value(), + &field_desc)); + case Kind::kDyn: + // google.protobuf.Value. + return context.value_factory() + .CreateBorrowedListValue< + ParsedProtoListValue>( + owner_from_this(), field.type.As(), + reflect.GetRepeatedFieldRef(value(), + &field_desc)); case Kind::kBool: // google.protobuf.BoolValue, mapped to CEL primitive bool type for // list elements. @@ -2264,6 +2500,21 @@ absl::StatusOr> ParsedProtoStructValue::GetSingularField( return context.value_factory().CreateUncheckedTimestampValue( timestamp); } + case Kind::kList: + // google.protobuf.ListValue + return protobuf_internal::CreateBorrowedListValue( + owner_from_this(), context.value_factory(), + reflect.GetMessage(value(), &field_desc)); + case Kind::kMap: + // google.protobuf.Struct + return protobuf_internal::CreateBorrowedStruct( + owner_from_this(), context.value_factory(), + reflect.GetMessage(value(), &field_desc)); + case Kind::kDyn: + // google.protobuf.Value + return protobuf_internal::CreateBorrowedValue( + owner_from_this(), context.value_factory(), + reflect.GetMessage(value(), &field_desc)); case Kind::kWrapper: { if (context.unbox_null_wrapper_types() && !reflect.HasField(value(), &field_desc)) { diff --git a/extensions/protobuf/struct_value.h b/extensions/protobuf/struct_value.h index 8705c227f..a081bfcaf 100644 --- a/extensions/protobuf/struct_value.h +++ b/extensions/protobuf/struct_value.h @@ -30,6 +30,7 @@ #include "absl/types/optional.h" #include "base/handle.h" #include "base/kind.h" +#include "base/owner.h" #include "base/type.h" #include "base/types/struct_type.h" #include "base/value.h" @@ -148,9 +149,18 @@ class ProtoStructValue : public CEL_STRUCT_VALUE_CLASS { static EnableIfDerivedMessage>> Create(ValueFactory& value_factory, T&& value); + template + static EnableIfDerivedMessage>> + CreateBorrowed(Owner owner, ValueFactory& value_factory, + const T& value ABSL_ATTRIBUTE_LIFETIME_BOUND); + static absl::StatusOr> Create( ValueFactory& value_factory, const google::protobuf::Message& message); + static absl::StatusOr> CreateBorrowed( + Owner owner, ValueFactory& value_factory, + const google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND); + static absl::StatusOr> Create( ValueFactory& value_factory, google::protobuf::Message&& message); @@ -164,6 +174,48 @@ class ProtoStructValue : public CEL_STRUCT_VALUE_CLASS { namespace protobuf_internal { +// Declare here but implemented in value.cc to give ProtoStructValue access to +// the conversion logic in value.cc. Creates a borrowed `ListValue` over +// `google.protobuf.ListValue`. +// +// Borrowing here means we are borrowing some native representation owned by +// `owner` and creating a new value which references that native representation, +// but does not own it. +absl::StatusOr> CreateBorrowedListValue( + Owner owner, ValueFactory& value_factory, + const google::protobuf::Message& value ABSL_ATTRIBUTE_LIFETIME_BOUND); + +// Declare here but implemented in value.cc to give ProtoStructValue access to +// the conversion logic in value.cc. Creates a borrowed `MapValue` over +// `google.protobuf.Struct`. +// +// Borrowing here means we are borrowing some native representation owned by +// `owner` and creating a new value which references that native representation, +// but does not own it. +absl::StatusOr> CreateBorrowedStruct( + Owner owner, ValueFactory& value_factory, + const google::protobuf::Message& value ABSL_ATTRIBUTE_LIFETIME_BOUND); + +// Declare here but implemented in value.cc to give ProtoStructValue access to +// the conversion logic in value.cc. Creates a borrowed `Value` over +// `google.protobuf.Value`. +// +// Borrowing here means we are borrowing some native representation owned by +// `owner` and creating a new value which references that native representation, +// but does not own it. +absl::StatusOr> CreateBorrowedValue( + Owner owner, ValueFactory& value_factory, + const google::protobuf::Message& value ABSL_ATTRIBUTE_LIFETIME_BOUND); + +absl::StatusOr> CreateListValue( + ValueFactory& value_factory, std::unique_ptr value); + +absl::StatusOr> CreateStruct( + ValueFactory& value_factory, std::unique_ptr value); + +absl::StatusOr> CreateValue( + ValueFactory& value_factory, std::unique_ptr value); + // Base class of all implementations of `ProtoStructValue` that operate on // parsed protocol buffer messages. class ParsedProtoStructValue : public ProtoStructValue { @@ -258,6 +310,31 @@ class StaticParsedProtoStructValue final : public ParsedProtoStructValue { const T value_; }; +template +class HeapStaticParsedProtoStructValue : public ParsedProtoStructValue { + public: + HeapStaticParsedProtoStructValue(Handle type, const T* value) + : ParsedProtoStructValue(std::move(type)), value_(value) {} + + const google::protobuf::Message& value() const final { return *value_; } + + protected: + absl::optional ValueReference( + google::protobuf::Message& scratch, const google::protobuf::Descriptor& desc, + internal::TypeInfo type) const final { + static_cast(scratch); + static_cast(desc); + if (ABSL_PREDICT_FALSE(type != internal::TypeId())) { + return absl::nullopt; + } + ABSL_ASSERT(value().GetDescriptor() == &desc); + return &value(); + } + + private: + const T* const value_; +}; + // Base implementation of `ParsedProtoStructValue` which does not know the // concrete type of the protocol buffer message. The protocol buffer message is // referenced by pointer and is allocated with the same memory manager that @@ -289,7 +366,7 @@ class DynamicParsedProtoStructValue : public ParsedProtoStructValue { // Implementation of `DynamicParsedProtoStructValue` for Arena-based memory // managers. -class ArenaDynamicParsedProtoStructValue final +class ArenaDynamicParsedProtoStructValue : public DynamicParsedProtoStructValue { public: ArenaDynamicParsedProtoStructValue(Handle type, @@ -324,6 +401,18 @@ ProtoStructValue::Create(ValueFactory& value_factory, T&& value) { std::move(type), std::forward(value)); } +template +inline ProtoStructValue::EnableIfDerivedMessage< + T, absl::StatusOr>> +ProtoStructValue::CreateBorrowed(Owner owner, + ValueFactory& value_factory, const T& value) { + CEL_ASSIGN_OR_RETURN( + auto type, ProtoStructType::Resolve(value_factory.type_manager())); + return value_factory.CreateBorrowedStructValue< + protobuf_internal::HeapStaticParsedProtoStructValue>( + std::move(owner), std::move(type), &value); +} + } // namespace cel::extensions #endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_STRUCT_VALUE_H_ diff --git a/extensions/protobuf/struct_value_test.cc b/extensions/protobuf/struct_value_test.cc index 87e07af88..a3453363a 100644 --- a/extensions/protobuf/struct_value_test.cc +++ b/extensions/protobuf/struct_value_test.cc @@ -23,6 +23,7 @@ #include "absl/time/time.h" #include "absl/types/optional.h" #include "base/internal/memory_manager_testing.h" +#include "base/testing/value_matchers.h" #include "base/type_factory.h" #include "base/type_manager.h" #include "base/types/struct_type.h" @@ -40,8 +41,10 @@ namespace cel::extensions { namespace { +using ::cel_testing::ValueOf; using testing::Eq; using testing::EqualsProto; +using testing::Optional; using testing::status::CanonicalStatusIs; using cel::internal::IsOkAndHolds; @@ -397,7 +400,7 @@ TEST_P(ProtoStructValueTest, EnumHasField) { IsOkAndHolds(Eq(true))); } -TEST_P(ProtoStructValueTest, StructHasField) { +TEST_P(ProtoStructValueTest, MessageHasField) { TypeFactory type_factory(memory_manager()); ProtoTypeProvider type_provider; TypeManager type_manager(type_factory, type_provider); @@ -627,6 +630,72 @@ TEST_P(ProtoStructValueTest, StringWrapperHasField) { IsOkAndHolds(Eq(true))); } +TEST_P(ProtoStructValueTest, ListValueHasField) { + TypeFactory type_factory(memory_manager()); + ProtoTypeProvider type_provider; + TypeManager type_manager(type_factory, type_provider); + ValueFactory value_factory(type_manager); + ASSERT_OK_AND_ASSIGN(auto value_without, + ProtoValue::Create(value_factory, CreateTestMessage())); + EXPECT_THAT( + value_without->HasField(StructValue::HasFieldContext(type_manager), + ProtoStructType::FieldId("list_value")), + IsOkAndHolds(Eq(false))); + ASSERT_OK_AND_ASSIGN( + auto value_with, + ProtoValue::Create(value_factory, + CreateTestMessage([](TestAllTypes& message) { + message.mutable_list_value(); + }))); + EXPECT_THAT(value_with->HasField(StructValue::HasFieldContext(type_manager), + ProtoStructType::FieldId("list_value")), + IsOkAndHolds(Eq(true))); +} + +TEST_P(ProtoStructValueTest, StructHasField) { + TypeFactory type_factory(memory_manager()); + ProtoTypeProvider type_provider; + TypeManager type_manager(type_factory, type_provider); + ValueFactory value_factory(type_manager); + ASSERT_OK_AND_ASSIGN(auto value_without, + ProtoValue::Create(value_factory, CreateTestMessage())); + EXPECT_THAT( + value_without->HasField(StructValue::HasFieldContext(type_manager), + ProtoStructType::FieldId("single_struct")), + IsOkAndHolds(Eq(false))); + ASSERT_OK_AND_ASSIGN( + auto value_with, + ProtoValue::Create(value_factory, + CreateTestMessage([](TestAllTypes& message) { + message.mutable_single_struct(); + }))); + EXPECT_THAT(value_with->HasField(StructValue::HasFieldContext(type_manager), + ProtoStructType::FieldId("single_struct")), + IsOkAndHolds(Eq(true))); +} + +TEST_P(ProtoStructValueTest, ValueHasField) { + TypeFactory type_factory(memory_manager()); + ProtoTypeProvider type_provider; + TypeManager type_manager(type_factory, type_provider); + ValueFactory value_factory(type_manager); + ASSERT_OK_AND_ASSIGN(auto value_without, + ProtoValue::Create(value_factory, CreateTestMessage())); + EXPECT_THAT( + value_without->HasField(StructValue::HasFieldContext(type_manager), + ProtoStructType::FieldId("single_value")), + IsOkAndHolds(Eq(false))); + ASSERT_OK_AND_ASSIGN( + auto value_with, + ProtoValue::Create(value_factory, + CreateTestMessage([](TestAllTypes& message) { + message.mutable_single_value(); + }))); + EXPECT_THAT(value_with->HasField(StructValue::HasFieldContext(type_manager), + ProtoStructType::FieldId("single_value")), + IsOkAndHolds(Eq(true))); +} + TEST_P(ProtoStructValueTest, NullValueListHasField) { TypeFactory type_factory(memory_manager()); ProtoTypeProvider type_provider; @@ -917,7 +986,7 @@ TEST_P(ProtoStructValueTest, EnumListHasField) { IsOkAndHolds(Eq(true))); } -TEST_P(ProtoStructValueTest, StructListHasField) { +TEST_P(ProtoStructValueTest, MessageListHasField) { TypeFactory type_factory(memory_manager()); ProtoTypeProvider type_provider; TypeManager type_manager(type_factory, type_provider); @@ -1147,6 +1216,73 @@ TEST_P(ProtoStructValueTest, StringWrapperListHasField) { IsOkAndHolds(Eq(true))); } +TEST_P(ProtoStructValueTest, ListValueListHasField) { + TypeFactory type_factory(memory_manager()); + ProtoTypeProvider type_provider; + TypeManager type_manager(type_factory, type_provider); + ValueFactory value_factory(type_manager); + ASSERT_OK_AND_ASSIGN(auto value_without, + ProtoValue::Create(value_factory, CreateTestMessage())); + EXPECT_THAT( + value_without->HasField(StructValue::HasFieldContext(type_manager), + ProtoStructType::FieldId("repeated_list_value")), + IsOkAndHolds(Eq(false))); + ASSERT_OK_AND_ASSIGN( + auto value_with, + ProtoValue::Create(value_factory, + CreateTestMessage([](TestAllTypes& message) { + message.add_repeated_list_value(); + }))); + EXPECT_THAT( + value_with->HasField(StructValue::HasFieldContext(type_manager), + ProtoStructType::FieldId("repeated_list_value")), + IsOkAndHolds(Eq(true))); +} + +TEST_P(ProtoStructValueTest, StructListHasField) { + TypeFactory type_factory(memory_manager()); + ProtoTypeProvider type_provider; + TypeManager type_manager(type_factory, type_provider); + ValueFactory value_factory(type_manager); + ASSERT_OK_AND_ASSIGN(auto value_without, + ProtoValue::Create(value_factory, CreateTestMessage())); + EXPECT_THAT( + value_without->HasField(StructValue::HasFieldContext(type_manager), + ProtoStructType::FieldId("repeated_struct")), + IsOkAndHolds(Eq(false))); + ASSERT_OK_AND_ASSIGN( + auto value_with, + ProtoValue::Create(value_factory, + CreateTestMessage([](TestAllTypes& message) { + message.add_repeated_struct(); + }))); + EXPECT_THAT(value_with->HasField(StructValue::HasFieldContext(type_manager), + ProtoStructType::FieldId("repeated_struct")), + IsOkAndHolds(Eq(true))); +} + +TEST_P(ProtoStructValueTest, ValueListHasField) { + TypeFactory type_factory(memory_manager()); + ProtoTypeProvider type_provider; + TypeManager type_manager(type_factory, type_provider); + ValueFactory value_factory(type_manager); + ASSERT_OK_AND_ASSIGN(auto value_without, + ProtoValue::Create(value_factory, CreateTestMessage())); + EXPECT_THAT( + value_without->HasField(StructValue::HasFieldContext(type_manager), + ProtoStructType::FieldId("repeated_value")), + IsOkAndHolds(Eq(false))); + ASSERT_OK_AND_ASSIGN( + auto value_with, + ProtoValue::Create(value_factory, + CreateTestMessage([](TestAllTypes& message) { + message.add_repeated_value(); + }))); + EXPECT_THAT(value_with->HasField(StructValue::HasFieldContext(type_manager), + ProtoStructType::FieldId("repeated_value")), + IsOkAndHolds(Eq(true))); +} + TEST_P(ProtoStructValueTest, NullValueGetField) { TypeFactory type_factory(memory_manager()); ProtoTypeProvider type_provider; @@ -1486,7 +1622,7 @@ TEST_P(ProtoStructValueTest, EnumGetField) { EXPECT_EQ(field.As()->number(), 1); } -TEST_P(ProtoStructValueTest, StructGetField) { +TEST_P(ProtoStructValueTest, MessageGetField) { TypeFactory type_factory(memory_manager()); ProtoTypeProvider type_provider; TypeManager type_manager(type_factory, type_provider); @@ -1810,6 +1946,91 @@ TEST_P(ProtoStructValueTest, StringWrapperGetField) { EXPECT_EQ(field.As()->ToString(), "foo"); } +TEST_P(ProtoStructValueTest, StructGetField) { + TypeFactory type_factory(memory_manager()); + ProtoTypeProvider type_provider; + TypeManager type_manager(type_factory, type_provider); + ValueFactory value_factory(type_manager); + ASSERT_OK_AND_ASSIGN(auto value_without, + ProtoValue::Create(value_factory, CreateTestMessage())); + ASSERT_OK_AND_ASSIGN( + auto field, + value_without->GetField(StructValue::GetFieldContext(value_factory), + ProtoStructType::FieldId("single_struct"))); + ASSERT_TRUE(field->Is()); + EXPECT_TRUE(field->As().empty()); + ASSERT_OK_AND_ASSIGN( + auto value_with, + ProtoValue::Create( + value_factory, CreateTestMessage([](TestAllTypes& message) { + google::protobuf::Value value_proto; + value_proto.set_bool_value(true); + message.mutable_single_struct()->mutable_fields()->insert( + {"foo", std::move(value_proto)}); + }))); + ASSERT_OK_AND_ASSIGN( + field, value_with->GetField(StructValue::GetFieldContext(value_factory), + ProtoStructType::FieldId("single_struct"))); + ASSERT_TRUE(field->Is()); + EXPECT_EQ(field->As().size(), 1); + ASSERT_OK_AND_ASSIGN(auto key, value_factory.CreateStringValue("foo")); + EXPECT_THAT( + field->As().Get(MapValue::GetContext(value_factory), key), + IsOkAndHolds(Optional(ValueOf(value_factory, true)))); +} + +TEST_P(ProtoStructValueTest, ListValueGetField) { + TypeFactory type_factory(memory_manager()); + ProtoTypeProvider type_provider; + TypeManager type_manager(type_factory, type_provider); + ValueFactory value_factory(type_manager); + ASSERT_OK_AND_ASSIGN(auto value_without, + ProtoValue::Create(value_factory, CreateTestMessage())); + ASSERT_OK_AND_ASSIGN( + auto field, + value_without->GetField(StructValue::GetFieldContext(value_factory), + ProtoStructType::FieldId("list_value"))); + ASSERT_TRUE(field->Is()); + EXPECT_TRUE(field->As().empty()); + ASSERT_OK_AND_ASSIGN( + auto value_with, + ProtoValue::Create( + value_factory, CreateTestMessage([](TestAllTypes& message) { + message.mutable_list_value()->add_values()->set_bool_value(true); + }))); + ASSERT_OK_AND_ASSIGN( + field, value_with->GetField(StructValue::GetFieldContext(value_factory), + ProtoStructType::FieldId("list_value"))); + ASSERT_TRUE(field->Is()); + EXPECT_EQ(field->As().size(), 1); + EXPECT_THAT( + field->As().Get(ListValue::GetContext(value_factory), 0), + IsOkAndHolds(ValueOf(value_factory, true))); +} + +TEST_P(ProtoStructValueTest, ValueGetField) { + TypeFactory type_factory(memory_manager()); + ProtoTypeProvider type_provider; + TypeManager type_manager(type_factory, type_provider); + ValueFactory value_factory(type_manager); + ASSERT_OK_AND_ASSIGN(auto value_without, + ProtoValue::Create(value_factory, CreateTestMessage())); + ASSERT_OK_AND_ASSIGN( + auto field, + value_without->GetField(StructValue::GetFieldContext(value_factory), + ProtoStructType::FieldId("single_value"))); + EXPECT_TRUE(field->Is()); + ASSERT_OK_AND_ASSIGN( + auto value_with, + ProtoValue::Create(value_factory, + CreateTestMessage([](TestAllTypes& message) { + message.mutable_single_value()->set_bool_value(true); + }))); + EXPECT_THAT(value_with->GetField(StructValue::GetFieldContext(value_factory), + ProtoStructType::FieldId("single_value")), + IsOkAndHolds(ValueOf(value_factory, true))); +} + TEST_P(ProtoStructValueTest, NullValueListGetField) { TypeFactory type_factory(memory_manager()); ProtoTypeProvider type_provider; diff --git a/extensions/protobuf/type.cc b/extensions/protobuf/type.cc index 87d97500a..342cc43ed 100644 --- a/extensions/protobuf/type.cc +++ b/extensions/protobuf/type.cc @@ -23,6 +23,19 @@ namespace cel::extensions { +namespace { + +bool IsJsonMap(const Type& type) { + return type.Is() && type.As().key()->Is() && + type.As().value()->Is(); +} + +bool IsJsonList(const Type& type) { + return type.Is() && type.As().element()->Is(); +} + +} // namespace + absl::StatusOr> ProtoType::Resolve( TypeManager& type_manager, const google::protobuf::EnumDescriptor& descriptor) { CEL_ASSIGN_OR_RETURN(auto type, @@ -50,9 +63,11 @@ absl::StatusOr> ProtoType::Resolve( absl::StrCat("Missing protocol buffer type implementation for \"", descriptor.full_name(), "\"")); } - if (ABSL_PREDICT_FALSE( - !(*type)->Is() && !(*type)->Is() && - !(*type)->Is() && !(*type)->Is())) { + if (ABSL_PREDICT_FALSE(!(*type)->Is() && + !(*type)->Is() && + !(*type)->Is() && + !(*type)->Is() && !IsJsonList(**type) && + !IsJsonMap(**type) && !(*type)->Is())) { return absl::FailedPreconditionError( absl::StrCat("Unexpected protocol buffer type implementation for \"", descriptor.full_name(), "\": ", (*type)->DebugString())); diff --git a/extensions/protobuf/value.cc b/extensions/protobuf/value.cc index ada479ee9..8665a2f8a 100644 --- a/extensions/protobuf/value.cc +++ b/extensions/protobuf/value.cc @@ -14,16 +14,34 @@ #include "extensions/protobuf/value.h" +#include +#include +#include +#include #include +#include +#include "google/protobuf/struct.pb.h" +#include "absl/base/attributes.h" +#include "absl/base/call_once.h" +#include "absl/base/macros.h" #include "absl/base/optimization.h" +#include "absl/memory/memory.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/cord.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" #include "absl/types/variant.h" +#include "base/handle.h" +#include "base/value.h" +#include "base/values/list_value.h" +#include "base/values/map_value.h" #include "extensions/protobuf/internal/time.h" #include "extensions/protobuf/internal/wrappers.h" +#include "extensions/protobuf/memory_manager.h" +#include "internal/casts.h" #include "internal/status_macros.h" namespace cel::extensions { @@ -42,6 +60,1335 @@ struct CreateStringValueFromProtoVisitor final { } }; +void AppendJsonValueDebugString(std::string& out, + const google::protobuf::Value& value); + +void AppendJsonValueDebugString(std::string& out, + const google::protobuf::ListValue& value) { + out.push_back('['); + auto current = value.values().begin(); + if (current != value.values().end()) { + AppendJsonValueDebugString(out, *current++); + } + for (; current != value.values().end(); ++current) { + out.append(", "); + AppendJsonValueDebugString(out, *current); + } + out.push_back(']'); +} + +void AppendJsonValueDebugString(std::string& out, + const google::protobuf::Struct& value) { + out.push_back('{'); + std::vector field_names; + field_names.reserve(value.fields_size()); + for (const auto& field : value.fields()) { + field_names.push_back(field.first); + } + std::stable_sort(field_names.begin(), field_names.end()); + auto current = field_names.cbegin(); + if (current != field_names.cend()) { + out.append(StringValue::DebugString(*current)); + out.append(": "); + AppendJsonValueDebugString(out, value.fields().at(*current++)); + for (; current != field_names.cend(); ++current) { + out.append(", "); + out.append(StringValue::DebugString(*current)); + out.append(": "); + AppendJsonValueDebugString(out, value.fields().at(*current)); + } + } + out.push_back('}'); +} + +void AppendJsonValueDebugString(std::string& out, + const google::protobuf::Value& value) { + switch (value.kind_case()) { + case google::protobuf::Value::KIND_NOT_SET: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::Value::kNullValue: + out.append(NullValue::DebugString()); + break; + case google::protobuf::Value::kBoolValue: + out.append(BoolValue::DebugString(value.bool_value())); + break; + case google::protobuf::Value::kNumberValue: + out.append(DoubleValue::DebugString(value.number_value())); + break; + case google::protobuf::Value::kStringValue: + out.append(StringValue::DebugString(value.string_value())); + break; + case google::protobuf::Value::kListValue: + AppendJsonValueDebugString(out, value.list_value()); + break; + case google::protobuf::Value::kStructValue: + AppendJsonValueDebugString(out, value.struct_value()); + break; + default: + break; + } +} + +template +absl::StatusOr> CreateMemberJsonValue( + ValueFactory& value_factory, const google::protobuf::ListValue& value, + Owner reference); + +template +absl::StatusOr> CreateMemberJsonValue( + ValueFactory& value_factory, const google::protobuf::Struct& value, + Owner reference); + +template +absl::StatusOr> CreateMemberJsonValue( + ValueFactory& value_factory, const google::protobuf::Value& value, + HandleFromThis&& owner_from_this) { + switch (value.kind_case()) { + case google::protobuf::Value::kNullValue: + return value_factory.GetNullValue(); + case google::protobuf::Value::kBoolValue: + return value_factory.CreateBoolValue(value.bool_value()); + case google::protobuf::Value::kNumberValue: + return value_factory.CreateDoubleValue(value.number_value()); + case google::protobuf::Value::kStringValue: + return value_factory.CreateBorrowedStringValue(owner_from_this(), + value.string_value()); + case google::protobuf::Value::kListValue: + return CreateMemberJsonValue(value_factory, value.list_value(), + owner_from_this()); + case google::protobuf::Value::kStructValue: + return CreateMemberJsonValue(value_factory, value.struct_value(), + owner_from_this()); + default: + return absl::InvalidArgumentError(absl::StrCat( + "unexpected google.protobuf.Value kind: %d", value.kind_case())); + } +} + +class StaticProtoJsonListValue : public CEL_LIST_VALUE_CLASS { + public: + StaticProtoJsonListValue(Handle type, + google::protobuf::ListValue value) + : CEL_LIST_VALUE_CLASS(std::move(type)), value_(std::move(value)) {} + + std::string DebugString() const final { + std::string out; + AppendJsonValueDebugString(out, value_); + return out; + } + + size_t size() const final { return value_.values_size(); } + + absl::StatusOr> Get(const GetContext& context, + size_t index) const final { + return CreateMemberJsonValue( + context.value_factory(), value_.values(index), + [this]() mutable { return owner_from_this(); }); + } + + private: + // Called by CEL_IMPLEMENT_LIST_VALUE() and Is() to perform type checking. + internal::TypeInfo TypeId() const final { + return internal::TypeId(); + } + + const google::protobuf::ListValue value_; +}; + +class ArenaStaticProtoJsonListValue : public CEL_LIST_VALUE_CLASS { + public: + ArenaStaticProtoJsonListValue(Handle type, + const google::protobuf::ListValue* value) + : CEL_LIST_VALUE_CLASS(std::move(type)), value_(value) {} + + std::string DebugString() const final { + std::string out; + AppendJsonValueDebugString(out, *value_); + return out; + } + + size_t size() const final { return value_->values_size(); } + + absl::StatusOr> Get(const GetContext& context, + size_t index) const final { + return CreateMemberJsonValue( + context.value_factory(), value_->values(index), + [this]() mutable { return owner_from_this(); }); + } + + private: + // Called by CEL_IMPLEMENT_LIST_VALUE() and Is() to perform type checking. + internal::TypeInfo TypeId() const final { + return internal::TypeId(); + } + + const google::protobuf::ListValue* const value_; +}; + +class StaticProtoJsonMapKeysListValue : public CEL_LIST_VALUE_CLASS { + public: + StaticProtoJsonMapKeysListValue( + Handle type, const google::protobuf::Struct* value, + std::vector> field_names) + : CEL_LIST_VALUE_CLASS(std::move(type)), + value_(value), + field_names_(std::move(field_names)) {} + + std::string DebugString() const final { + std::string out; + AppendJsonValueDebugString(out, *value_); + return out; + } + + size_t size() const final { return field_names_.size(); } + + absl::StatusOr> Get(const GetContext& context, + size_t index) const final { + return CreateMemberJsonValue( + context.value_factory(), value_->fields().at(field_names_[index]), + [this]() mutable { return owner_from_this(); }); + } + + private: + // Called by CEL_IMPLEMENT_LIST_VALUE() and Is() to perform type checking. + internal::TypeInfo TypeId() const final { + return internal::TypeId(); + } + + const google::protobuf::Struct* const value_; + std::vector> field_names_; +}; + +class StaticProtoJsonMapValue : public CEL_MAP_VALUE_CLASS { + public: + StaticProtoJsonMapValue(Handle type, google::protobuf::Struct value) + : CEL_MAP_VALUE_CLASS(std::move(type)), value_(std::move(value)) {} + + std::string DebugString() const final { + std::string out; + AppendJsonValueDebugString(out, value_); + return out; + } + + size_t size() const final { return value_.fields_size(); } + + absl::StatusOr>> Get( + const GetContext& context, const Handle& key) const final { + if (!key->Is()) { + return absl::InvalidArgumentError("expected key to be string value"); + } + auto it = value_.fields().find(key->As().ToString()); + if (it == value_.fields().end()) { + return absl::nullopt; + } + return CreateMemberJsonValue( + context.value_factory(), it->second, + [this]() mutable { return owner_from_this(); }); + } + + absl::StatusOr Has(const HasContext& context, + const Handle& key) const final { + if (!key->Is()) { + return absl::InvalidArgumentError("expected key to be string value"); + } + return value_.fields().contains(key->As().ToString()); + } + + absl::StatusOr> ListKeys( + const ListKeysContext& context) const final { + CEL_ASSIGN_OR_RETURN( + auto list_type, + context.value_factory().type_factory().CreateListType(type()->key())); + std::vector> field_names( + Allocator(context.value_factory().memory_manager())); + field_names.reserve(value_.fields_size()); + for (const auto& field : value_.fields()) { + field_names.push_back(field.first); + } + return context.value_factory() + .CreateBorrowedListValue( + owner_from_this(), std::move(list_type), &value_, + std::move(field_names)); + } + + private: + // Called by CEL_IMPLEMENT_MAP_VALUE() and Is() to perform type checking. + internal::TypeInfo TypeId() const override { + return internal::TypeId(); + } + + const google::protobuf::Struct value_; +}; + +class ArenaStaticProtoJsonMapValue : public CEL_MAP_VALUE_CLASS { + public: + ArenaStaticProtoJsonMapValue(Handle type, + const google::protobuf::Struct* value) + : CEL_MAP_VALUE_CLASS(std::move(type)), value_(value) {} + + std::string DebugString() const final { + std::string out; + AppendJsonValueDebugString(out, *value_); + return out; + } + + size_t size() const final { return value_->fields_size(); } + + absl::StatusOr>> Get( + const GetContext& context, const Handle& key) const final { + if (!key->Is()) { + return absl::InvalidArgumentError("expected key to be string value"); + } + auto it = value_->fields().find(key->As().ToString()); + if (it == value_->fields().end()) { + return absl::nullopt; + } + return CreateMemberJsonValue( + context.value_factory(), it->second, + [this]() mutable { return owner_from_this(); }); + } + + absl::StatusOr Has(const HasContext& context, + const Handle& key) const final { + if (!key->Is()) { + return absl::InvalidArgumentError("expected key to be string value"); + } + return value_->fields().contains(key->As().ToString()); + } + + absl::StatusOr> ListKeys( + const ListKeysContext& context) const final { + CEL_ASSIGN_OR_RETURN( + auto list_type, + context.value_factory().type_factory().CreateListType(type()->key())); + std::vector> field_names( + Allocator(context.value_factory().memory_manager())); + field_names.reserve(value_->fields_size()); + for (const auto& field : value_->fields()) { + field_names.push_back(field.first); + } + return context.value_factory() + .CreateBorrowedListValue( + owner_from_this(), std::move(list_type), value_, + std::move(field_names)); + } + + private: + // Called by CEL_IMPLEMENT_MAP_VALUE() and Is() to perform type checking. + internal::TypeInfo TypeId() const final { + return internal::TypeId(); + } + + const google::protobuf::Struct* const value_; +}; + +template +absl::StatusOr> CreateMemberJsonValue( + ValueFactory& value_factory, const google::protobuf::ListValue& value, + Owner reference) { + CEL_ASSIGN_OR_RETURN(auto list_type, + value_factory.type_factory().CreateListType( + value_factory.type_factory().GetDynType())); + return value_factory.CreateBorrowedListValue( + std::move(reference), std::move(list_type), &value); +} + +template +absl::StatusOr> CreateMemberJsonValue( + ValueFactory& value_factory, const google::protobuf::Struct& value, + Owner reference) { + CEL_ASSIGN_OR_RETURN(auto map_type, + value_factory.type_factory().CreateMapType( + value_factory.type_factory().GetStringType(), + value_factory.type_factory().GetDynType())); + return value_factory.CreateBorrowedMapValue( + std::move(reference), std::move(map_type), &value); +} + +} // namespace + +absl::StatusOr> ProtoValue::Create( + ValueFactory& value_factory, google::protobuf::ListValue value) { + CEL_ASSIGN_OR_RETURN(auto list_type, + value_factory.type_factory().CreateListType( + value_factory.type_factory().GetDynType())); + if (ProtoMemoryManager::Is(value_factory.memory_manager())) { + auto* arena = + ProtoMemoryManager::CastToProtoArena(value_factory.memory_manager()); + if (arena != nullptr) { + auto* arena_value = + google::protobuf::Arena::CreateMessage(arena); + *arena_value = std::move(value); + return value_factory.CreateListValue( + std::move(list_type), arena_value); + } + } + return value_factory.CreateListValue( + std::move(list_type), std::move(value)); +} + +absl::StatusOr> ProtoValue::Create( + ValueFactory& value_factory, + std::unique_ptr value) { + CEL_ASSIGN_OR_RETURN(auto list_type, + value_factory.type_factory().CreateListType( + value_factory.type_factory().GetDynType())); + if (ProtoMemoryManager::Is(value_factory.memory_manager())) { + auto* arena = + ProtoMemoryManager::CastToProtoArena(value_factory.memory_manager()); + if (arena != nullptr) { + auto* arena_value = + google::protobuf::Arena::CreateMessage(arena); + arena_value->Swap(value.get()); + return value_factory.CreateListValue( + std::move(list_type), arena_value); + } + } + return value_factory.CreateListValue( + std::move(list_type), std::move(*value)); +} + +absl::StatusOr> ProtoValue::CreateBorrowed( + Owner owner, ValueFactory& value_factory, + const google::protobuf::ListValue& value) { + CEL_ASSIGN_OR_RETURN(auto list_type, + value_factory.type_factory().CreateListType( + value_factory.type_factory().GetDynType())); + return value_factory.CreateBorrowedListValue( + std::move(owner), std::move(list_type), &value); +} + +absl::StatusOr> ProtoValue::Create( + ValueFactory& value_factory, google::protobuf::Struct value) { + CEL_ASSIGN_OR_RETURN(auto map_type, + value_factory.type_factory().CreateMapType( + value_factory.type_factory().GetStringType(), + value_factory.type_factory().GetDynType())); + if (ProtoMemoryManager::Is(value_factory.memory_manager())) { + auto* arena = + ProtoMemoryManager::CastToProtoArena(value_factory.memory_manager()); + if (arena != nullptr) { + auto* arena_value = + google::protobuf::Arena::CreateMessage(arena); + *arena_value = std::move(value); + return value_factory.CreateMapValue( + std::move(map_type), arena_value); + } + } + return value_factory.CreateMapValue( + std::move(map_type), std::move(value)); +} + +absl::StatusOr> ProtoValue::Create( + ValueFactory& value_factory, + std::unique_ptr value) { + CEL_ASSIGN_OR_RETURN(auto map_type, + value_factory.type_factory().CreateMapType( + value_factory.type_factory().GetStringType(), + value_factory.type_factory().GetDynType())); + if (ProtoMemoryManager::Is(value_factory.memory_manager())) { + auto* arena = + ProtoMemoryManager::CastToProtoArena(value_factory.memory_manager()); + if (arena != nullptr) { + auto* arena_value = + google::protobuf::Arena::CreateMessage(arena); + *arena_value = std::move(*value); + return value_factory.CreateMapValue( + std::move(map_type), arena_value); + } + } + return value_factory.CreateMapValue( + std::move(map_type), std::move(*value)); +} + +absl::StatusOr> ProtoValue::CreateBorrowed( + Owner owner, ValueFactory& value_factory, + const google::protobuf::Struct& value) { + CEL_ASSIGN_OR_RETURN(auto map_type, + value_factory.type_factory().CreateMapType( + value_factory.type_factory().GetStringType(), + value_factory.type_factory().GetDynType())); + return value_factory.CreateBorrowedMapValue( + std::move(owner), std::move(map_type), &value); +} + +absl::StatusOr> ProtoValue::Create( + ValueFactory& value_factory, + std::unique_ptr value) { + switch (value->kind_case()) { + case google::protobuf::Value::KIND_NOT_SET: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::Value::kNullValue: + return value_factory.GetNullValue(); + case google::protobuf::Value::kBoolValue: + return value_factory.CreateBoolValue(value->bool_value()); + case google::protobuf::Value::kNumberValue: + return value_factory.CreateDoubleValue(value->number_value()); + case google::protobuf::Value::kStringValue: + return value_factory.CreateUncheckedStringValue( + std::move(*value->mutable_string_value())); + case google::protobuf::Value::kListValue: + return Create(value_factory, + absl::WrapUnique(value->release_list_value())); + case google::protobuf::Value::kStructValue: + return Create(value_factory, + absl::WrapUnique(value->release_struct_value())); + default: + return absl::InvalidArgumentError(absl::StrCat( + "unexpected google.protobuf.Value kind: %d", value->kind_case())); + } +} + +absl::StatusOr> ProtoValue::Create( + ValueFactory& value_factory, google::protobuf::Value value) { + switch (value.kind_case()) { + case google::protobuf::Value::KIND_NOT_SET: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::Value::kNullValue: + return value_factory.GetNullValue(); + case google::protobuf::Value::kBoolValue: + return value_factory.CreateBoolValue(value.bool_value()); + case google::protobuf::Value::kNumberValue: + return value_factory.CreateDoubleValue(value.number_value()); + case google::protobuf::Value::kStringValue: + return value_factory.CreateUncheckedStringValue(value.string_value()); + case google::protobuf::Value::kListValue: + return Create(value_factory, std::move(*value.mutable_list_value())); + case google::protobuf::Value::kStructValue: + return Create(value_factory, std::move(*value.mutable_struct_value())); + default: + return absl::InvalidArgumentError(absl::StrCat( + "unexpected google.protobuf.Value kind: %d", value.kind_case())); + } +} + +absl::StatusOr> ProtoValue::CreateBorrowed( + Owner owner, ValueFactory& value_factory, + const google::protobuf::Value& value) { + switch (value.kind_case()) { + case google::protobuf::Value::KIND_NOT_SET: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::Value::kNullValue: + return value_factory.GetNullValue(); + case google::protobuf::Value::kBoolValue: + return value_factory.CreateBoolValue(value.bool_value()); + case google::protobuf::Value::kNumberValue: + return value_factory.CreateDoubleValue(value.number_value()); + case google::protobuf::Value::kStringValue: + return value_factory.CreateBorrowedStringValue(std::move(owner), + value.string_value()); + case google::protobuf::Value::kListValue: + return CreateBorrowed(std::move(owner), value_factory, + value.list_value()); + case google::protobuf::Value::kStructValue: + return CreateBorrowed(std::move(owner), value_factory, + value.struct_value()); + default: + return absl::InvalidArgumentError(absl::StrCat( + "unexpected google.protobuf.Value kind: %d", value.kind_case())); + } +} + +namespace { + +using DynamicMessageCopyConverter = + absl::StatusOr> (*)(ValueFactory&, const google::protobuf::Message&); +using DynamicMessageMoveConverter = + absl::StatusOr> (*)(ValueFactory&, google::protobuf::Message&&); +using DynamicMessageBorrowConverter = absl::StatusOr> (*)( + Owner&, ValueFactory&, const google::protobuf::Message&); + +using DynamicMessageConverter = + std::tuple; + +absl::StatusOr> DurationMessageCopyConverter( + ValueFactory& value_factory, const google::protobuf::Message& value) { + CEL_ASSIGN_OR_RETURN(auto duration, + protobuf_internal::AbslDurationFromDurationProto(value)); + return value_factory.CreateUncheckedDurationValue(duration); +} + +absl::StatusOr> DurationMessageMoveConverter( + ValueFactory& value_factory, google::protobuf::Message&& value) { + CEL_ASSIGN_OR_RETURN(auto duration, + protobuf_internal::AbslDurationFromDurationProto(value)); + return value_factory.CreateUncheckedDurationValue(duration); +} + +absl::StatusOr> DurationMessageBorrowConverter( + Owner& owner ABSL_ATTRIBUTE_UNUSED, ValueFactory& value_factory, + const google::protobuf::Message& value) { + CEL_ASSIGN_OR_RETURN(auto duration, + protobuf_internal::AbslDurationFromDurationProto(value)); + return value_factory.CreateUncheckedDurationValue(duration); +} + +absl::StatusOr> TimestampMessageCopyConverter( + ValueFactory& value_factory, const google::protobuf::Message& value) { + CEL_ASSIGN_OR_RETURN(auto time, + protobuf_internal::AbslTimeFromTimestampProto(value)); + return value_factory.CreateUncheckedTimestampValue(time); +} + +absl::StatusOr> TimestampMessageMoveConverter( + ValueFactory& value_factory, google::protobuf::Message&& value) { + CEL_ASSIGN_OR_RETURN(auto time, + protobuf_internal::AbslTimeFromTimestampProto(value)); + return value_factory.CreateUncheckedTimestampValue(time); +} + +absl::StatusOr> TimestampMessageBorrowConverter( + Owner& owner ABSL_ATTRIBUTE_UNUSED, ValueFactory& value_factory, + const google::protobuf::Message& value) { + CEL_ASSIGN_OR_RETURN(auto time, + protobuf_internal::AbslTimeFromTimestampProto(value)); + return value_factory.CreateUncheckedTimestampValue(time); +} + +absl::StatusOr> BoolValueMessageCopyConverter( + ValueFactory& value_factory, const google::protobuf::Message& value) { + CEL_ASSIGN_OR_RETURN(auto wrapped, + protobuf_internal::UnwrapBoolValueProto(value)); + return value_factory.CreateBoolValue(wrapped); +} + +absl::StatusOr> BoolValueMessageMoveConverter( + ValueFactory& value_factory, google::protobuf::Message&& value) { + CEL_ASSIGN_OR_RETURN(auto wrapped, + protobuf_internal::UnwrapBoolValueProto(value)); + return value_factory.CreateBoolValue(wrapped); +} + +absl::StatusOr> BoolValueMessageBorrowConverter( + Owner& owner ABSL_ATTRIBUTE_UNUSED, ValueFactory& value_factory, + const google::protobuf::Message& value) { + CEL_ASSIGN_OR_RETURN(auto wrapped, + protobuf_internal::UnwrapBoolValueProto(value)); + return value_factory.CreateBoolValue(wrapped); +} + +absl::StatusOr> BytesValueMessageCopyConverter( + ValueFactory& value_factory, const google::protobuf::Message& value) { + CEL_ASSIGN_OR_RETURN(auto wrapped, + protobuf_internal::UnwrapBytesValueProto(value)); + return value_factory.CreateBytesValue(std::move(wrapped)); +} + +absl::StatusOr> BytesValueMessageMoveConverter( + ValueFactory& value_factory, google::protobuf::Message&& value) { + CEL_ASSIGN_OR_RETURN(auto wrapped, + protobuf_internal::UnwrapBytesValueProto(value)); + return value_factory.CreateBytesValue(std::move(wrapped)); +} + +absl::StatusOr> BytesValueMessageBorrowConverter( + Owner& owner ABSL_ATTRIBUTE_UNUSED, ValueFactory& value_factory, + const google::protobuf::Message& value) { + CEL_ASSIGN_OR_RETURN(auto wrapped, + protobuf_internal::UnwrapBytesValueProto(value)); + return value_factory.CreateBytesValue(std::move(wrapped)); +} + +absl::StatusOr> FloatValueMessageCopyConverter( + ValueFactory& value_factory, const google::protobuf::Message& value) { + CEL_ASSIGN_OR_RETURN(auto wrapped, + protobuf_internal::UnwrapFloatValueProto(value)); + return value_factory.CreateDoubleValue(wrapped); +} + +absl::StatusOr> FloatValueMessageMoveConverter( + ValueFactory& value_factory, google::protobuf::Message&& value) { + CEL_ASSIGN_OR_RETURN(auto wrapped, + protobuf_internal::UnwrapFloatValueProto(value)); + return value_factory.CreateDoubleValue(wrapped); +} + +absl::StatusOr> FloatValueMessageBorrowConverter( + Owner& owner ABSL_ATTRIBUTE_UNUSED, ValueFactory& value_factory, + const google::protobuf::Message& value) { + CEL_ASSIGN_OR_RETURN(auto wrapped, + protobuf_internal::UnwrapFloatValueProto(value)); + return value_factory.CreateDoubleValue(wrapped); +} + +absl::StatusOr> DoubleValueMessageCopyConverter( + ValueFactory& value_factory, const google::protobuf::Message& value) { + CEL_ASSIGN_OR_RETURN(auto wrapped, + protobuf_internal::UnwrapDoubleValueProto(value)); + return value_factory.CreateDoubleValue(wrapped); +} + +absl::StatusOr> DoubleValueMessageMoveConverter( + ValueFactory& value_factory, google::protobuf::Message&& value) { + CEL_ASSIGN_OR_RETURN(auto wrapped, + protobuf_internal::UnwrapDoubleValueProto(value)); + return value_factory.CreateDoubleValue(wrapped); +} + +absl::StatusOr> DoubleValueMessageBorrowConverter( + Owner& owner ABSL_ATTRIBUTE_UNUSED, ValueFactory& value_factory, + const google::protobuf::Message& value) { + CEL_ASSIGN_OR_RETURN(auto wrapped, + protobuf_internal::UnwrapDoubleValueProto(value)); + return value_factory.CreateDoubleValue(wrapped); +} + +absl::StatusOr> Int32ValueMessageCopyConverter( + ValueFactory& value_factory, const google::protobuf::Message& value) { + CEL_ASSIGN_OR_RETURN(auto wrapped, + protobuf_internal::UnwrapInt32ValueProto(value)); + return value_factory.CreateIntValue(wrapped); +} + +absl::StatusOr> Int32ValueMessageMoveConverter( + ValueFactory& value_factory, google::protobuf::Message&& value) { + CEL_ASSIGN_OR_RETURN(auto wrapped, + protobuf_internal::UnwrapInt32ValueProto(value)); + return value_factory.CreateIntValue(wrapped); +} + +absl::StatusOr> Int32ValueMessageBorrowConverter( + Owner& owner ABSL_ATTRIBUTE_UNUSED, ValueFactory& value_factory, + const google::protobuf::Message& value) { + CEL_ASSIGN_OR_RETURN(auto wrapped, + protobuf_internal::UnwrapInt32ValueProto(value)); + return value_factory.CreateIntValue(wrapped); +} + +absl::StatusOr> Int64ValueMessageCopyConverter( + ValueFactory& value_factory, const google::protobuf::Message& value) { + CEL_ASSIGN_OR_RETURN(auto wrapped, + protobuf_internal::UnwrapInt64ValueProto(value)); + return value_factory.CreateIntValue(wrapped); +} + +absl::StatusOr> Int64ValueMessageMoveConverter( + ValueFactory& value_factory, google::protobuf::Message&& value) { + CEL_ASSIGN_OR_RETURN(auto wrapped, + protobuf_internal::UnwrapInt64ValueProto(value)); + return value_factory.CreateIntValue(wrapped); +} + +absl::StatusOr> Int64ValueMessageBorrowConverter( + Owner& owner ABSL_ATTRIBUTE_UNUSED, ValueFactory& value_factory, + const google::protobuf::Message& value) { + CEL_ASSIGN_OR_RETURN(auto wrapped, + protobuf_internal::UnwrapInt64ValueProto(value)); + return value_factory.CreateIntValue(wrapped); +} + +absl::StatusOr> StringValueMessageCopyConverter( + ValueFactory& value_factory, const google::protobuf::Message& value) { + CEL_ASSIGN_OR_RETURN(auto wrapped, + protobuf_internal::UnwrapStringValueProto(value)); + return absl::visit(CreateStringValueFromProtoVisitor{value_factory}, + std::move(wrapped)); +} + +absl::StatusOr> StringValueMessageMoveConverter( + ValueFactory& value_factory, google::protobuf::Message&& value) { + CEL_ASSIGN_OR_RETURN(auto wrapped, + protobuf_internal::UnwrapStringValueProto(value)); + return absl::visit(CreateStringValueFromProtoVisitor{value_factory}, + std::move(wrapped)); +} + +absl::StatusOr> StringValueMessageBorrowConverter( + Owner& owner ABSL_ATTRIBUTE_UNUSED, ValueFactory& value_factory, + const google::protobuf::Message& value) { + CEL_ASSIGN_OR_RETURN(auto wrapped, + protobuf_internal::UnwrapStringValueProto(value)); + return absl::visit(CreateStringValueFromProtoVisitor{value_factory}, + std::move(wrapped)); +} + +absl::StatusOr> UInt32ValueMessageCopyConverter( + ValueFactory& value_factory, const google::protobuf::Message& value) { + CEL_ASSIGN_OR_RETURN(auto wrapped, + protobuf_internal::UnwrapUInt32ValueProto(value)); + return value_factory.CreateUintValue(wrapped); +} + +absl::StatusOr> UInt32ValueMessageMoveConverter( + ValueFactory& value_factory, google::protobuf::Message&& value) { + CEL_ASSIGN_OR_RETURN(auto wrapped, + protobuf_internal::UnwrapUInt32ValueProto(value)); + return value_factory.CreateUintValue(wrapped); +} + +absl::StatusOr> UInt32ValueMessageBorrowConverter( + Owner& owner ABSL_ATTRIBUTE_UNUSED, ValueFactory& value_factory, + const google::protobuf::Message& value) { + CEL_ASSIGN_OR_RETURN(auto wrapped, + protobuf_internal::UnwrapUInt32ValueProto(value)); + return value_factory.CreateUintValue(wrapped); +} + +absl::StatusOr> UInt64ValueMessageCopyConverter( + ValueFactory& value_factory, const google::protobuf::Message& value) { + CEL_ASSIGN_OR_RETURN(auto wrapped, + protobuf_internal::UnwrapUInt64ValueProto(value)); + return value_factory.CreateUintValue(wrapped); +} + +absl::StatusOr> UInt64ValueMessageMoveConverter( + ValueFactory& value_factory, google::protobuf::Message&& value) { + CEL_ASSIGN_OR_RETURN(auto wrapped, + protobuf_internal::UnwrapUInt64ValueProto(value)); + return value_factory.CreateUintValue(wrapped); +} + +absl::StatusOr> UInt64ValueMessageBorrowConverter( + Owner& owner ABSL_ATTRIBUTE_UNUSED, ValueFactory& value_factory, + const google::protobuf::Message& value) { + CEL_ASSIGN_OR_RETURN(auto wrapped, + protobuf_internal::UnwrapUInt64ValueProto(value)); + return value_factory.CreateUintValue(wrapped); +} + +absl::StatusOr> StructMessageCopyConverter( + ValueFactory& value_factory, const google::protobuf::Message& value) { + if (value.GetDescriptor() == google::protobuf::Struct::descriptor()) { + return ProtoValue::Create( + value_factory, + cel::internal::down_cast(value)); + } + std::string serialized; + if (!value.SerializePartialToString(&serialized)) { + return absl::InternalError("failed to serialize google.protobuf.Struct"); + } + if (ProtoMemoryManager::Is(value_factory.memory_manager())) { + auto* arena = + ProtoMemoryManager::CastToProtoArena(value_factory.memory_manager()); + if (arena != nullptr) { + CEL_ASSIGN_OR_RETURN(auto map_type, + value_factory.type_factory().CreateMapType( + value_factory.type_factory().GetStringType(), + value_factory.type_factory().GetDynType())); + auto* arena_value = + google::protobuf::Arena::CreateMessage(arena); + if (!arena_value->ParsePartialFromString(serialized)) { + return absl::InternalError("failed to parse google.protobuf.Struct"); + } + return value_factory.CreateMapValue( + std::move(map_type), arena_value); + } + } + google::protobuf::Struct parsed; + if (!parsed.ParsePartialFromString(serialized)) { + return absl::InternalError("failed to parse google.protobuf.Struct"); + } + return ProtoValue::Create(value_factory, std::move(parsed)); +} + +absl::StatusOr> StructMessageMoveConverter( + ValueFactory& value_factory, google::protobuf::Message&& value) { + if (value.GetDescriptor() == google::protobuf::Struct::descriptor()) { + return ProtoValue::Create( + value_factory, + std::move(cel::internal::down_cast(value))); + } + std::string serialized; + if (!value.SerializePartialToString(&serialized)) { + return absl::InternalError("failed to serialize google.protobuf.Struct"); + } + if (ProtoMemoryManager::Is(value_factory.memory_manager())) { + auto* arena = + ProtoMemoryManager::CastToProtoArena(value_factory.memory_manager()); + if (arena != nullptr) { + CEL_ASSIGN_OR_RETURN(auto map_type, + value_factory.type_factory().CreateMapType( + value_factory.type_factory().GetStringType(), + value_factory.type_factory().GetDynType())); + auto* arena_value = + google::protobuf::Arena::CreateMessage(arena); + if (!arena_value->ParsePartialFromString(serialized)) { + return absl::InternalError("failed to parse google.protobuf.Struct"); + } + return value_factory.CreateMapValue( + std::move(map_type), arena_value); + } + } + google::protobuf::Struct parsed; + if (!parsed.ParsePartialFromString(serialized)) { + return absl::InternalError("failed to parse google.protobuf.Struct"); + } + return ProtoValue::Create(value_factory, std::move(parsed)); +} + +absl::StatusOr> StructMessageBorrowConverter( + Owner& owner, ValueFactory& value_factory, + const google::protobuf::Message& value) { + if (value.GetDescriptor() == google::protobuf::Struct::descriptor()) { + return ProtoValue::Create( + value_factory, + cel::internal::down_cast(value)); + } + std::string serialized; + if (!value.SerializePartialToString(&serialized)) { + return absl::InternalError("failed to serialize google.protobuf.Struct"); + } + if (ProtoMemoryManager::Is(value_factory.memory_manager())) { + auto* arena = + ProtoMemoryManager::CastToProtoArena(value_factory.memory_manager()); + if (arena != nullptr) { + CEL_ASSIGN_OR_RETURN(auto map_type, + value_factory.type_factory().CreateMapType( + value_factory.type_factory().GetStringType(), + value_factory.type_factory().GetDynType())); + auto* arena_value = + google::protobuf::Arena::CreateMessage(arena); + if (!arena_value->ParsePartialFromString(serialized)) { + return absl::InternalError("failed to parse google.protobuf.Struct"); + } + return value_factory.CreateMapValue( + std::move(map_type), arena_value); + } + } + google::protobuf::Struct parsed; + if (!parsed.ParsePartialFromString(serialized)) { + return absl::InternalError("failed to parse google.protobuf.Struct"); + } + return ProtoValue::Create(value_factory, std::move(parsed)); +} + +absl::StatusOr> StructMessageOwnConverter( + ValueFactory& value_factory, std::unique_ptr value) { + if (value->GetDescriptor() == google::protobuf::Struct::descriptor()) { + return ProtoValue::Create( + value_factory, + std::move(cel::internal::down_cast(*value))); + } + std::string serialized; + if (!value->SerializePartialToString(&serialized)) { + return absl::InternalError("failed to serialize google.protobuf.Struct"); + } + value.reset(); + if (ProtoMemoryManager::Is(value_factory.memory_manager())) { + auto* arena = + ProtoMemoryManager::CastToProtoArena(value_factory.memory_manager()); + if (arena != nullptr) { + CEL_ASSIGN_OR_RETURN(auto map_type, + value_factory.type_factory().CreateMapType( + value_factory.type_factory().GetStringType(), + value_factory.type_factory().GetDynType())); + auto* arena_value = + google::protobuf::Arena::CreateMessage(arena); + if (!arena_value->ParsePartialFromString(serialized)) { + return absl::InternalError("failed to parse google.protobuf.Struct"); + } + return value_factory.CreateMapValue( + std::move(map_type), arena_value); + } + } + google::protobuf::Struct parsed; + if (!parsed.ParsePartialFromString(serialized)) { + return absl::InternalError("failed to parse google.protobuf.Struct"); + } + return ProtoValue::Create(value_factory, std::move(parsed)); +} + +absl::StatusOr> ListValueMessageCopyConverter( + ValueFactory& value_factory, const google::protobuf::Message& value) { + if (value.GetDescriptor() == google::protobuf::ListValue::descriptor()) { + return ProtoValue::Create( + value_factory, + cel::internal::down_cast(value)); + } + std::string serialized; + if (!value.SerializePartialToString(&serialized)) { + return absl::InternalError("failed to serialize google.protobuf.ListValue"); + } + if (ProtoMemoryManager::Is(value_factory.memory_manager())) { + auto* arena = + ProtoMemoryManager::CastToProtoArena(value_factory.memory_manager()); + if (arena != nullptr) { + CEL_ASSIGN_OR_RETURN(auto map_type, + value_factory.type_factory().CreateListType( + value_factory.type_factory().GetDynType())); + auto* arena_value = + google::protobuf::Arena::CreateMessage(arena); + if (!arena_value->ParsePartialFromString(serialized)) { + return absl::InternalError("failed to parse google.protobuf.ListValue"); + } + return value_factory.CreateListValue( + std::move(map_type), arena_value); + } + } + google::protobuf::ListValue parsed; + if (!parsed.ParsePartialFromString(serialized)) { + return absl::InternalError("failed to parse google.protobuf.ListValue"); + } + return ProtoValue::Create(value_factory, std::move(parsed)); +} + +absl::StatusOr> ListValueMessageMoveConverter( + ValueFactory& value_factory, google::protobuf::Message&& value) { + if (value.GetDescriptor() == google::protobuf::ListValue::descriptor()) { + return ProtoValue::Create( + value_factory, + std::move( + cel::internal::down_cast(value))); + } + std::string serialized; + if (!value.SerializePartialToString(&serialized)) { + return absl::InternalError("failed to serialize google.protobuf.ListValue"); + } + if (ProtoMemoryManager::Is(value_factory.memory_manager())) { + auto* arena = + ProtoMemoryManager::CastToProtoArena(value_factory.memory_manager()); + if (arena != nullptr) { + CEL_ASSIGN_OR_RETURN(auto map_type, + value_factory.type_factory().CreateListType( + value_factory.type_factory().GetDynType())); + auto* arena_value = + google::protobuf::Arena::CreateMessage(arena); + if (!arena_value->ParsePartialFromString(serialized)) { + return absl::InternalError("failed to parse google.protobuf.ListValue"); + } + return value_factory.CreateListValue( + std::move(map_type), arena_value); + } + } + google::protobuf::ListValue parsed; + if (!parsed.ParsePartialFromString(serialized)) { + return absl::InternalError("failed to parse google.protobuf.ListValue"); + } + return ProtoValue::Create(value_factory, std::move(parsed)); +} + +absl::StatusOr> ListValueMessageBorrowConverter( + Owner& owner, ValueFactory& value_factory, + const google::protobuf::Message& value) { + if (value.GetDescriptor() == google::protobuf::ListValue::descriptor()) { + return ProtoValue::Create( + value_factory, + cel::internal::down_cast(value)); + } + std::string serialized; + if (!value.SerializePartialToString(&serialized)) { + return absl::InternalError("failed to serialize google.protobuf.ListValue"); + } + if (ProtoMemoryManager::Is(value_factory.memory_manager())) { + auto* arena = + ProtoMemoryManager::CastToProtoArena(value_factory.memory_manager()); + if (arena != nullptr) { + CEL_ASSIGN_OR_RETURN(auto map_type, + value_factory.type_factory().CreateListType( + value_factory.type_factory().GetDynType())); + auto* arena_value = + google::protobuf::Arena::CreateMessage(arena); + if (!arena_value->ParsePartialFromString(serialized)) { + return absl::InternalError("failed to parse google.protobuf.ListValue"); + } + return value_factory.CreateListValue( + std::move(map_type), arena_value); + } + } + google::protobuf::ListValue parsed; + if (!parsed.ParsePartialFromString(serialized)) { + return absl::InternalError("failed to parse google.protobuf.ListValue"); + } + return ProtoValue::Create(value_factory, std::move(parsed)); +} + +absl::StatusOr> ListValueMessageOwnConverter( + ValueFactory& value_factory, std::unique_ptr value) { + if (value->GetDescriptor() == google::protobuf::ListValue::descriptor()) { + return ProtoValue::Create( + value_factory, + std::move( + cel::internal::down_cast(*value))); + } + std::string serialized; + if (!value->SerializePartialToString(&serialized)) { + return absl::InternalError("failed to serialize google.protobuf.ListValue"); + } + value.reset(); + if (ProtoMemoryManager::Is(value_factory.memory_manager())) { + auto* arena = + ProtoMemoryManager::CastToProtoArena(value_factory.memory_manager()); + if (arena != nullptr) { + CEL_ASSIGN_OR_RETURN(auto type, + value_factory.type_factory().CreateListType( + value_factory.type_factory().GetDynType())); + auto* arena_value = + google::protobuf::Arena::CreateMessage(arena); + if (!arena_value->ParsePartialFromString(serialized)) { + return absl::InternalError("failed to parse google.protobuf.ListValue"); + } + return value_factory.CreateListValue( + std::move(type), arena_value); + } + } + google::protobuf::ListValue parsed; + if (!parsed.ParsePartialFromString(serialized)) { + return absl::InternalError("failed to parse google.protobuf.ListValue"); + } + return ProtoValue::Create(value_factory, std::move(parsed)); +} + +absl::StatusOr> ValueMessageCopyConverter( + ValueFactory& value_factory, const google::protobuf::Message& value) { + const auto* desc = value.GetDescriptor(); + if (desc == google::protobuf::Value::descriptor()) { + return ProtoValue::Create( + value_factory, + cel::internal::down_cast(value)); + } + const auto* oneof_desc = desc->FindOneofByName("kind"); + if (ABSL_PREDICT_FALSE(oneof_desc == nullptr)) { + return absl::InvalidArgumentError( + "oneof descriptor missing for google.protobuf.Value"); + } + const auto* reflect = value.GetReflection(); + if (ABSL_PREDICT_FALSE(reflect == nullptr)) { + return absl::InvalidArgumentError( + "reflection missing for google.protobuf.Value"); + } + const auto* field_desc = reflect->GetOneofFieldDescriptor(value, oneof_desc); + if (ABSL_PREDICT_FALSE(field_desc == nullptr)) { + return value_factory.GetNullValue(); + } + switch (field_desc->number()) { + case google::protobuf::Value::kNullValueFieldNumber: + return value_factory.GetNullValue(); + case google::protobuf::Value::kBoolValueFieldNumber: + return value_factory.CreateBoolValue(reflect->GetBool(value, field_desc)); + case google::protobuf::Value::kNumberValueFieldNumber: + return value_factory.CreateDoubleValue( + reflect->GetDouble(value, field_desc)); + case google::protobuf::Value::kStringValueFieldNumber: + return value_factory.CreateStringValue( + reflect->GetStringView(value, field_desc)); + case google::protobuf::Value::kListValueFieldNumber: + return ListValueMessageCopyConverter( + value_factory, reflect->GetMessage(value, field_desc)); + case google::protobuf::Value::kStructValueFieldNumber: + return StructMessageCopyConverter(value_factory, + reflect->GetMessage(value, field_desc)); + default: + return absl::InvalidArgumentError( + absl::StrCat("unexpected oneof field set for google.protobuf.Value: ", + field_desc->number())); + } +} + +absl::StatusOr> ValueMessageMoveConverter( + ValueFactory& value_factory, google::protobuf::Message&& value) { + const auto* desc = value.GetDescriptor(); + if (desc == google::protobuf::Value::descriptor()) { + return ProtoValue::Create( + value_factory, + std::move(cel::internal::down_cast(value))); + } + const auto* oneof_desc = desc->FindOneofByName("kind"); + if (ABSL_PREDICT_FALSE(oneof_desc == nullptr)) { + return absl::InvalidArgumentError( + "oneof descriptor missing for google.protobuf.Value"); + } + const auto* reflect = value.GetReflection(); + if (ABSL_PREDICT_FALSE(reflect == nullptr)) { + return absl::InvalidArgumentError( + "reflection missing for google.protobuf.Value"); + } + const auto* field_desc = reflect->GetOneofFieldDescriptor(value, oneof_desc); + if (ABSL_PREDICT_FALSE(field_desc == nullptr)) { + return value_factory.GetNullValue(); + } + switch (field_desc->number()) { + case google::protobuf::Value::kNullValueFieldNumber: + return value_factory.GetNullValue(); + case google::protobuf::Value::kBoolValueFieldNumber: + return value_factory.CreateBoolValue(reflect->GetBool(value, field_desc)); + case google::protobuf::Value::kNumberValueFieldNumber: + return value_factory.CreateDoubleValue( + reflect->GetDouble(value, field_desc)); + case google::protobuf::Value::kStringValueFieldNumber: + return value_factory.CreateStringValue( + reflect->GetStringView(value, field_desc)); + case google::protobuf::Value::kListValueFieldNumber: + return ListValueMessageMoveConverter( + value_factory, + std::move(*reflect->MutableMessage(&value, field_desc))); + case google::protobuf::Value::kStructValueFieldNumber: + return StructMessageMoveConverter( + value_factory, + std::move(*reflect->MutableMessage(&value, field_desc))); + default: + return absl::InvalidArgumentError( + absl::StrCat("unexpected oneof field set for google.protobuf.Value: ", + field_desc->number())); + } +} + +absl::StatusOr> ValueMessageBorrowConverter( + Owner& owner, ValueFactory& value_factory, + const google::protobuf::Message& value) { + const auto* desc = value.GetDescriptor(); + if (desc == google::protobuf::Value::descriptor()) { + return ProtoValue::Create( + value_factory, + cel::internal::down_cast(value)); + } + const auto* oneof_desc = desc->FindOneofByName("kind"); + if (ABSL_PREDICT_FALSE(oneof_desc == nullptr)) { + return absl::InvalidArgumentError( + "oneof descriptor missing for google.protobuf.Value"); + } + const auto* reflect = value.GetReflection(); + if (ABSL_PREDICT_FALSE(reflect == nullptr)) { + return absl::InvalidArgumentError( + "reflection missing for google.protobuf.Value"); + } + const auto* field_desc = reflect->GetOneofFieldDescriptor(value, oneof_desc); + if (ABSL_PREDICT_FALSE(field_desc == nullptr)) { + return value_factory.GetNullValue(); + } + switch (field_desc->number()) { + case google::protobuf::Value::kNullValueFieldNumber: + return value_factory.GetNullValue(); + case google::protobuf::Value::kBoolValueFieldNumber: + return value_factory.CreateBoolValue(reflect->GetBool(value, field_desc)); + case google::protobuf::Value::kNumberValueFieldNumber: + return value_factory.CreateDoubleValue( + reflect->GetDouble(value, field_desc)); + case google::protobuf::Value::kStringValueFieldNumber: + return value_factory.CreateBorrowedStringValue( + std::move(owner), reflect->GetStringView(value, field_desc)); + case google::protobuf::Value::kListValueFieldNumber: + return ListValueMessageBorrowConverter( + owner, value_factory, reflect->GetMessage(value, field_desc)); + case google::protobuf::Value::kStructValueFieldNumber: + return StructMessageBorrowConverter( + owner, value_factory, reflect->GetMessage(value, field_desc)); + default: + return absl::InvalidArgumentError( + absl::StrCat("unexpected oneof field set for google.protobuf.Value: ", + field_desc->number())); + } +} + +absl::StatusOr> ValueMessageOwnConverter( + ValueFactory& value_factory, std::unique_ptr value) { + const auto* desc = value->GetDescriptor(); + if (desc == google::protobuf::Value::descriptor()) { + return ProtoValue::Create( + value_factory, + absl::WrapUnique(cel::internal::down_cast( + value.release()))); + } + const auto* oneof_desc = desc->FindOneofByName("kind"); + if (ABSL_PREDICT_FALSE(oneof_desc == nullptr)) { + return absl::InvalidArgumentError( + "oneof descriptor missing for google.protobuf.Value"); + } + const auto* reflect = value->GetReflection(); + if (ABSL_PREDICT_FALSE(reflect == nullptr)) { + return absl::InvalidArgumentError( + "reflection missing for google.protobuf.Value"); + } + const auto* field_desc = reflect->GetOneofFieldDescriptor(*value, oneof_desc); + if (ABSL_PREDICT_FALSE(field_desc == nullptr)) { + return value_factory.GetNullValue(); + } + switch (field_desc->number()) { + case google::protobuf::Value::kNullValueFieldNumber: + return value_factory.GetNullValue(); + case google::protobuf::Value::kBoolValueFieldNumber: + return value_factory.CreateBoolValue( + reflect->GetBool(*value, field_desc)); + case google::protobuf::Value::kNumberValueFieldNumber: + return value_factory.CreateDoubleValue( + reflect->GetDouble(*value, field_desc)); + case google::protobuf::Value::kStringValueFieldNumber: + return value_factory.CreateStringValue( + reflect->GetStringView(*value, field_desc)); + case google::protobuf::Value::kListValueFieldNumber: + return ListValueMessageCopyConverter( + value_factory, reflect->GetMessage(*value, field_desc)); + case google::protobuf::Value::kStructValueFieldNumber: + return StructMessageCopyConverter( + value_factory, reflect->GetMessage(*value, field_desc)); + default: + return absl::InvalidArgumentError( + absl::StrCat("unexpected oneof field set for google.protobuf.Value: ", + field_desc->number())); + } +} + +ABSL_CONST_INIT absl::once_flag proto_value_once; +ABSL_CONST_INIT DynamicMessageConverter dynamic_message_converters[] = { + {"google.protobuf.Duration", DurationMessageCopyConverter, + DurationMessageMoveConverter, DurationMessageBorrowConverter}, + {"google.protobuf.Timestamp", TimestampMessageCopyConverter, + TimestampMessageMoveConverter, TimestampMessageBorrowConverter}, + {"google.protobuf.BoolValue", BoolValueMessageCopyConverter, + BoolValueMessageMoveConverter, BoolValueMessageBorrowConverter}, + {"google.protobuf.BytesValue", BytesValueMessageCopyConverter, + BytesValueMessageMoveConverter, BytesValueMessageBorrowConverter}, + {"google.protobuf.FloatValue", FloatValueMessageCopyConverter, + FloatValueMessageMoveConverter, FloatValueMessageBorrowConverter}, + {"google.protobuf.DoubleValue", DoubleValueMessageCopyConverter, + DoubleValueMessageMoveConverter, DoubleValueMessageBorrowConverter}, + {"google.protobuf.Int32Value", Int32ValueMessageCopyConverter, + Int32ValueMessageMoveConverter, Int32ValueMessageBorrowConverter}, + {"google.protobuf.Int64Value", Int64ValueMessageCopyConverter, + Int64ValueMessageMoveConverter, Int64ValueMessageBorrowConverter}, + {"google.protobuf.StringValue", StringValueMessageCopyConverter, + StringValueMessageMoveConverter, StringValueMessageBorrowConverter}, + {"google.protobuf.UInt32Value", UInt32ValueMessageCopyConverter, + UInt32ValueMessageMoveConverter, UInt32ValueMessageBorrowConverter}, + {"google.protobuf.UInt64Value", UInt64ValueMessageCopyConverter, + UInt64ValueMessageMoveConverter, UInt64ValueMessageBorrowConverter}, + {"google.protobuf.Struct", StructMessageCopyConverter, + StructMessageMoveConverter, StructMessageBorrowConverter}, + {"google.protobuf.ListValue", ListValueMessageCopyConverter, + ListValueMessageMoveConverter, ListValueMessageBorrowConverter}, + {"google.protobuf.Value", ValueMessageCopyConverter, + ValueMessageMoveConverter, ValueMessageBorrowConverter}, +}; + +DynamicMessageConverter* dynamic_message_converters_begin() { + return dynamic_message_converters; +} + +DynamicMessageConverter* dynamic_message_converters_end() { + return dynamic_message_converters + + ABSL_ARRAYSIZE(dynamic_message_converters); +} + +const DynamicMessageConverter* dynamic_message_converters_cbegin() { + return dynamic_message_converters_begin(); +} + +const DynamicMessageConverter* dynamic_message_converters_cend() { + return dynamic_message_converters_end(); +} + +struct DynamicMessageConverterComparer { + bool operator()(const DynamicMessageConverter& lhs, + absl::string_view rhs) const { + return std::get(lhs) < rhs; + } + + bool operator()(absl::string_view lhs, + const DynamicMessageConverter& rhs) const { + return lhs < std::get(rhs); + } +}; + +void InitializeProtoValue() { + std::stable_sort(dynamic_message_converters_begin(), + dynamic_message_converters_end(), + [](const DynamicMessageConverter& lhs, + const DynamicMessageConverter& rhs) { + return std::get(lhs) < + std::get(rhs); + }); +} + } // namespace absl::StatusOr> ProtoValue::Create(ValueFactory& value_factory, @@ -51,53 +1398,39 @@ absl::StatusOr> ProtoValue::Create(ValueFactory& value_factory, return absl::InternalError("protocol buffer message missing descriptor"); } const auto& type_name = desc->full_name(); - if (type_name == "google.protobuf.Duration") { - CEL_ASSIGN_OR_RETURN( - auto duration, protobuf_internal::AbslDurationFromDurationProto(value)); - return value_factory.CreateUncheckedDurationValue(duration); - } - if (type_name == "google.protobuf.Timestamp") { - CEL_ASSIGN_OR_RETURN(auto time, - protobuf_internal::AbslTimeFromTimestampProto(value)); - return value_factory.CreateUncheckedTimestampValue(time); - } - if (type_name == "google.protobuf.BoolValue") { - CEL_ASSIGN_OR_RETURN(auto wrapped, - protobuf_internal::UnwrapBoolValueProto(value)); - return value_factory.CreateBoolValue(wrapped); - } - if (type_name == "google.protobuf.BytesValue") { - CEL_ASSIGN_OR_RETURN(auto wrapped, - protobuf_internal::UnwrapBytesValueProto(value)); - return value_factory.CreateBytesValue(std::move(wrapped)); - } - if (type_name == "google.protobuf.FloatValue" || - type_name == "google.protobuf.DoubleValue") { - CEL_ASSIGN_OR_RETURN(auto wrapped, - protobuf_internal::UnwrapDoubleValueProto(value)); - return value_factory.CreateDoubleValue(wrapped); - } - if (type_name == "google.protobuf.Int32Value" || - type_name == "google.protobuf.Int64Value") { - CEL_ASSIGN_OR_RETURN(auto wrapped, - protobuf_internal::UnwrapIntValueProto(value)); - return value_factory.CreateIntValue(wrapped); - } - if (type_name == "google.protobuf.StringValue") { - CEL_ASSIGN_OR_RETURN(auto wrapped, - protobuf_internal::UnwrapStringValueProto(value)); - return absl::visit(CreateStringValueFromProtoVisitor{value_factory}, - std::move(wrapped)); - } - if (type_name == "google.protobuf.UInt32Value" || - type_name == "google.protobuf.UInt64Value") { - CEL_ASSIGN_OR_RETURN(auto wrapped, - protobuf_internal::UnwrapUIntValueProto(value)); - return value_factory.CreateUintValue(wrapped); + absl::call_once(proto_value_once, InitializeProtoValue); + auto converter = std::lower_bound( + dynamic_message_converters_cbegin(), dynamic_message_converters_cend(), + type_name, DynamicMessageConverterComparer{}); + if (converter != dynamic_message_converters_cend() && + std::get(*converter) == type_name) { + return std::get(*converter)(value_factory, + value); } return ProtoStructValue::Create(value_factory, value); } +absl::StatusOr> ProtoValue::CreateBorrowed( + Owner owner, ValueFactory& value_factory, + const google::protobuf::Message& value) { + const auto* desc = value.GetDescriptor(); + if (ABSL_PREDICT_FALSE(desc == nullptr)) { + return absl::InternalError("protocol buffer message missing descriptor"); + } + const auto& type_name = desc->full_name(); + absl::call_once(proto_value_once, InitializeProtoValue); + auto converter = std::lower_bound( + dynamic_message_converters_cbegin(), dynamic_message_converters_cend(), + type_name, DynamicMessageConverterComparer{}); + if (converter != dynamic_message_converters_cend() && + std::get(*converter) == type_name) { + return std::get(*converter)( + owner, value_factory, value); + } + return ProtoStructValue::CreateBorrowed(std::move(owner), value_factory, + value); +} + absl::StatusOr> ProtoValue::Create(ValueFactory& value_factory, google::protobuf::Message&& value) { const auto* desc = value.GetDescriptor(); @@ -105,49 +1438,14 @@ absl::StatusOr> ProtoValue::Create(ValueFactory& value_factory, return absl::InternalError("protocol buffer message missing descriptor"); } const auto& type_name = desc->full_name(); - if (type_name == "google.protobuf.Duration") { - CEL_ASSIGN_OR_RETURN( - auto duration, protobuf_internal::AbslDurationFromDurationProto(value)); - return value_factory.CreateUncheckedDurationValue(duration); - } - if (type_name == "google.protobuf.Timestamp") { - CEL_ASSIGN_OR_RETURN(auto time, - protobuf_internal::AbslTimeFromTimestampProto(value)); - return value_factory.CreateUncheckedTimestampValue(time); - } - if (type_name == "google.protobuf.BoolValue") { - CEL_ASSIGN_OR_RETURN(auto wrapped, - protobuf_internal::UnwrapBoolValueProto(value)); - return value_factory.CreateBoolValue(wrapped); - } - if (type_name == "google.protobuf.BytesValue") { - CEL_ASSIGN_OR_RETURN(auto wrapped, - protobuf_internal::UnwrapBytesValueProto(value)); - return value_factory.CreateBytesValue(std::move(wrapped)); - } - if (type_name == "google.protobuf.FloatValue" || - type_name == "google.protobuf.DoubleValue") { - CEL_ASSIGN_OR_RETURN(auto wrapped, - protobuf_internal::UnwrapDoubleValueProto(value)); - return value_factory.CreateDoubleValue(wrapped); - } - if (type_name == "google.protobuf.Int32Value" || - type_name == "google.protobuf.Int64Value") { - CEL_ASSIGN_OR_RETURN(auto wrapped, - protobuf_internal::UnwrapIntValueProto(value)); - return value_factory.CreateIntValue(wrapped); - } - if (type_name == "google.protobuf.StringValue") { - CEL_ASSIGN_OR_RETURN(auto wrapped, - protobuf_internal::UnwrapStringValueProto(value)); - return absl::visit(CreateStringValueFromProtoVisitor{value_factory}, - std::move(wrapped)); - } - if (type_name == "google.protobuf.UInt32Value" || - type_name == "google.protobuf.UInt64Value") { - CEL_ASSIGN_OR_RETURN(auto wrapped, - protobuf_internal::UnwrapUIntValueProto(value)); - return value_factory.CreateUintValue(wrapped); + absl::call_once(proto_value_once, InitializeProtoValue); + auto converter = std::lower_bound( + dynamic_message_converters_cbegin(), dynamic_message_converters_cend(), + type_name, DynamicMessageConverterComparer{}); + if (converter != dynamic_message_converters_cend() && + std::get(*converter) == type_name) { + return std::get(*converter)(value_factory, + std::move(value)); } return ProtoStructValue::Create(value_factory, std::move(value)); } @@ -169,4 +1467,41 @@ absl::StatusOr> ProtoValue::Create( } } +namespace protobuf_internal { + +absl::StatusOr> CreateBorrowedListValue( + Owner owner, ValueFactory& value_factory, + const google::protobuf::Message& value) { + return ListValueMessageBorrowConverter(owner, value_factory, value); +} + +absl::StatusOr> CreateBorrowedStruct( + Owner owner, ValueFactory& value_factory, + const google::protobuf::Message& value) { + return StructMessageBorrowConverter(owner, value_factory, value); +} + +absl::StatusOr> CreateBorrowedValue( + Owner owner, ValueFactory& value_factory, + const google::protobuf::Message& value) { + return ValueMessageBorrowConverter(owner, value_factory, value); +} + +absl::StatusOr> CreateListValue( + ValueFactory& value_factory, std::unique_ptr value) { + return ListValueMessageOwnConverter(value_factory, std::move(value)); +} + +absl::StatusOr> CreateStruct( + ValueFactory& value_factory, std::unique_ptr value) { + return StructMessageOwnConverter(value_factory, std::move(value)); +} + +absl::StatusOr> CreateValue( + ValueFactory& value_factory, std::unique_ptr value) { + return ValueMessageOwnConverter(value_factory, std::move(value)); +} + +} // namespace protobuf_internal + } // namespace cel::extensions diff --git a/extensions/protobuf/value.h b/extensions/protobuf/value.h index 5366c46a9..be70fec67 100644 --- a/extensions/protobuf/value.h +++ b/extensions/protobuf/value.h @@ -15,15 +15,19 @@ #ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_VALUE_H_ #define THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_VALUE_H_ +#include #include #include #include "google/protobuf/duration.pb.h" +#include "google/protobuf/struct.pb.h" #include "google/protobuf/timestamp.pb.h" #include "google/protobuf/wrappers.pb.h" #include "absl/base/attributes.h" #include "absl/status/statusor.h" #include "absl/time/time.h" +#include "base/handle.h" +#include "base/owner.h" #include "base/value.h" #include "base/value_factory.h" #include "base/values/duration_value.h" @@ -109,6 +113,15 @@ class ProtoValue final { template using NotWrapperMessage = std::negation>; + template + using JsonMessage = std::disjunction< + std::is_same>, + std::is_same>, + std::is_same>>; + + template + using NotJsonMessage = std::negation>; + public: // Create a new EnumValue from a generated protocol buffer enum. template @@ -134,12 +147,24 @@ class ProtoValue final { template static std::enable_if_t< std::conjunction_v, NotDurationMessage, - NotTimestampMessage, NotWrapperMessage>, + NotTimestampMessage, NotWrapperMessage, + NotJsonMessage>, absl::StatusOr>> Create(ValueFactory& value_factory, T&& value) { return ProtoStructValue::Create(value_factory, std::forward(value)); } + template + static std::enable_if_t< + std::conjunction_v, NotDurationMessage, + NotTimestampMessage, NotWrapperMessage, + NotJsonMessage>, + absl::StatusOr>> + CreateBorrowed(ValueFactory& value_factory, + const T& value ABSL_ATTRIBUTE_LIFETIME_BOUND) { + return ProtoStructValue::Create(value_factory, value); + } + // Create a new DurationValue from google.protobuf.Duration. static absl::StatusOr> Create( ValueFactory& value_factory, const google::protobuf::Duration& value) { @@ -209,10 +234,48 @@ class ProtoValue final { return value_factory.CreateUintValue(value.value()); } + static absl::StatusOr> Create( + ValueFactory& value_factory, google::protobuf::ListValue value); + + static absl::StatusOr> Create( + ValueFactory& value_factory, + std::unique_ptr value); + + static absl::StatusOr> CreateBorrowed( + Owner owner, ValueFactory& value_factory, + const google::protobuf::ListValue& value ABSL_ATTRIBUTE_LIFETIME_BOUND); + + static absl::StatusOr> Create( + ValueFactory& value_factory, google::protobuf::Struct value); + + static absl::StatusOr> Create( + ValueFactory& value_factory, + std::unique_ptr value); + + static absl::StatusOr> CreateBorrowed( + Owner owner, ValueFactory& value_factory, + const google::protobuf::Struct& value ABSL_ATTRIBUTE_LIFETIME_BOUND); + + static absl::StatusOr> Create(ValueFactory& value_factory, + google::protobuf::Value value); + + static absl::StatusOr> Create( + ValueFactory& value_factory, + std::unique_ptr value); + + static absl::StatusOr> CreateBorrowed( + Owner owner, ValueFactory& value_factory, + const google::protobuf::Value& value ABSL_ATTRIBUTE_LIFETIME_BOUND); + // Create a new Value from a protocol buffer message. static absl::StatusOr> Create(ValueFactory& value_factory, const google::protobuf::Message& value); + // Create a new Value from a protocol buffer message. + static absl::StatusOr> CreateBorrowed( + Owner owner, ValueFactory& value_factory, + const google::protobuf::Message& value ABSL_ATTRIBUTE_LIFETIME_BOUND); + // Create a new Value from a protocol buffer message. static absl::StatusOr> Create(ValueFactory& value_factory, google::protobuf::Message&& value); diff --git a/extensions/protobuf/value_test.cc b/extensions/protobuf/value_test.cc index 2814097fc..a911de316 100644 --- a/extensions/protobuf/value_test.cc +++ b/extensions/protobuf/value_test.cc @@ -14,6 +14,8 @@ #include "extensions/protobuf/value.h" +#include + #include "google/protobuf/duration.pb.h" #include "google/protobuf/struct.pb.h" #include "base/internal/memory_manager_testing.h" @@ -230,6 +232,371 @@ TEST_P(ProtoValueTest, DynamicNullValue) { EXPECT_TRUE(value->Is()); } +TEST_P(ProtoValueTest, StaticValueNullValue) { + TypeFactory type_factory(memory_manager()); + ProtoTypeProvider type_provider; + TypeManager type_manager(type_factory, type_provider); + ValueFactory value_factory(type_manager); + auto value_proto = std::make_unique(); + value_proto->set_null_value(google::protobuf::NULL_VALUE); + EXPECT_THAT(ProtoValue::Create(value_factory, std::move(value_proto)), + IsOkAndHolds(ValueOf(value_factory))); +} + +TEST_P(ProtoValueTest, StaticLValueValueNullValue) { + TypeFactory type_factory(memory_manager()); + ProtoTypeProvider type_provider; + TypeManager type_manager(type_factory, type_provider); + ValueFactory value_factory(type_manager); + google::protobuf::Value value_proto; + value_proto.set_null_value(google::protobuf::NULL_VALUE); + EXPECT_THAT(ProtoValue::Create(value_factory, value_proto), + IsOkAndHolds(ValueOf(value_factory))); +} + +TEST_P(ProtoValueTest, StaticRValueValueNullValue) { + TypeFactory type_factory(memory_manager()); + ProtoTypeProvider type_provider; + TypeManager type_manager(type_factory, type_provider); + ValueFactory value_factory(type_manager); + google::protobuf::Value value_proto; + value_proto.set_null_value(google::protobuf::NULL_VALUE); + EXPECT_THAT(ProtoValue::Create(value_factory, std::move(value_proto)), + IsOkAndHolds(ValueOf(value_factory))); +} + +TEST_P(ProtoValueTest, StaticValueBoolValue) { + TypeFactory type_factory(memory_manager()); + ProtoTypeProvider type_provider; + TypeManager type_manager(type_factory, type_provider); + ValueFactory value_factory(type_manager); + auto value_proto = std::make_unique(); + value_proto->set_bool_value(true); + EXPECT_THAT(ProtoValue::Create(value_factory, std::move(value_proto)), + IsOkAndHolds(ValueOf(value_factory, true))); +} + +TEST_P(ProtoValueTest, StaticLValueValueBoolValue) { + TypeFactory type_factory(memory_manager()); + ProtoTypeProvider type_provider; + TypeManager type_manager(type_factory, type_provider); + ValueFactory value_factory(type_manager); + google::protobuf::Value value_proto; + value_proto.set_bool_value(true); + EXPECT_THAT(ProtoValue::Create(value_factory, value_proto), + IsOkAndHolds(ValueOf(value_factory, true))); +} + +TEST_P(ProtoValueTest, StaticRValueValueBoolValue) { + TypeFactory type_factory(memory_manager()); + ProtoTypeProvider type_provider; + TypeManager type_manager(type_factory, type_provider); + ValueFactory value_factory(type_manager); + google::protobuf::Value value_proto; + value_proto.set_bool_value(true); + EXPECT_THAT(ProtoValue::Create(value_factory, std::move(value_proto)), + IsOkAndHolds(ValueOf(value_factory, true))); +} + +TEST_P(ProtoValueTest, StaticValueNumberValue) { + TypeFactory type_factory(memory_manager()); + ProtoTypeProvider type_provider; + TypeManager type_manager(type_factory, type_provider); + ValueFactory value_factory(type_manager); + auto value_proto = std::make_unique(); + value_proto->set_number_value(1.0); + EXPECT_THAT(ProtoValue::Create(value_factory, std::move(value_proto)), + IsOkAndHolds(ValueOf(value_factory, 1.0))); +} + +TEST_P(ProtoValueTest, StaticLValueValueNumberValue) { + TypeFactory type_factory(memory_manager()); + ProtoTypeProvider type_provider; + TypeManager type_manager(type_factory, type_provider); + ValueFactory value_factory(type_manager); + google::protobuf::Value value_proto; + value_proto.set_number_value(1.0); + EXPECT_THAT(ProtoValue::Create(value_factory, value_proto), + IsOkAndHolds(ValueOf(value_factory, 1.0))); +} + +TEST_P(ProtoValueTest, StaticRValueValueNumberValue) { + TypeFactory type_factory(memory_manager()); + ProtoTypeProvider type_provider; + TypeManager type_manager(type_factory, type_provider); + ValueFactory value_factory(type_manager); + google::protobuf::Value value_proto; + value_proto.set_number_value(1.0); + EXPECT_THAT(ProtoValue::Create(value_factory, std::move(value_proto)), + IsOkAndHolds(ValueOf(value_factory, 1.0))); +} + +TEST_P(ProtoValueTest, StaticValueStringValue) { + TypeFactory type_factory(memory_manager()); + ProtoTypeProvider type_provider; + TypeManager type_manager(type_factory, type_provider); + ValueFactory value_factory(type_manager); + auto value_proto = std::make_unique(); + value_proto->set_string_value("foo"); + EXPECT_THAT(ProtoValue::Create(value_factory, std::move(value_proto)), + IsOkAndHolds(ValueOf(value_factory, "foo"))); +} + +TEST_P(ProtoValueTest, StaticLValueValueStringValue) { + TypeFactory type_factory(memory_manager()); + ProtoTypeProvider type_provider; + TypeManager type_manager(type_factory, type_provider); + ValueFactory value_factory(type_manager); + google::protobuf::Value value_proto; + value_proto.set_string_value("foo"); + EXPECT_THAT(ProtoValue::Create(value_factory, value_proto), + IsOkAndHolds(ValueOf(value_factory, "foo"))); +} + +TEST_P(ProtoValueTest, StaticRValueValueStringValue) { + TypeFactory type_factory(memory_manager()); + ProtoTypeProvider type_provider; + TypeManager type_manager(type_factory, type_provider); + ValueFactory value_factory(type_manager); + google::protobuf::Value value_proto; + value_proto.set_string_value("foo"); + EXPECT_THAT(ProtoValue::Create(value_factory, std::move(value_proto)), + IsOkAndHolds(ValueOf(value_factory, "foo"))); +} + +TEST_P(ProtoValueTest, StaticValueListValue) { + TypeFactory type_factory(memory_manager()); + ProtoTypeProvider type_provider; + TypeManager type_manager(type_factory, type_provider); + ValueFactory value_factory(type_manager); + auto value_proto = std::make_unique(); + value_proto->mutable_list_value()->add_values()->set_bool_value(true); + ASSERT_OK_AND_ASSIGN( + auto value, ProtoValue::Create(value_factory, std::move(value_proto))); + EXPECT_TRUE(value->Is()); + EXPECT_EQ(value->As().size(), 1); + EXPECT_THAT( + value->As().Get(ListValue::GetContext(value_factory), 0), + IsOkAndHolds(ValueOf(value_factory, true))); +} + +TEST_P(ProtoValueTest, StaticLValueValueListValue) { + TypeFactory type_factory(memory_manager()); + ProtoTypeProvider type_provider; + TypeManager type_manager(type_factory, type_provider); + ValueFactory value_factory(type_manager); + google::protobuf::Value value_proto; + value_proto.mutable_list_value()->add_values()->set_bool_value(true); + ASSERT_OK_AND_ASSIGN(auto value, + ProtoValue::Create(value_factory, value_proto)); + EXPECT_TRUE(value->Is()); + EXPECT_EQ(value->As().size(), 1); + EXPECT_THAT( + value->As().Get(ListValue::GetContext(value_factory), 0), + IsOkAndHolds(ValueOf(value_factory, true))); +} + +TEST_P(ProtoValueTest, StaticRValueValueListValue) { + TypeFactory type_factory(memory_manager()); + ProtoTypeProvider type_provider; + TypeManager type_manager(type_factory, type_provider); + ValueFactory value_factory(type_manager); + google::protobuf::Value value_proto; + value_proto.mutable_list_value()->add_values()->set_bool_value(true); + ASSERT_OK_AND_ASSIGN( + auto value, ProtoValue::Create(value_factory, std::move(value_proto))); + EXPECT_TRUE(value->Is()); + EXPECT_EQ(value->As().size(), 1); + EXPECT_THAT( + value->As().Get(ListValue::GetContext(value_factory), 0), + IsOkAndHolds(ValueOf(value_factory, true))); +} + +TEST_P(ProtoValueTest, StaticValueStructValue) { + TypeFactory type_factory(memory_manager()); + ProtoTypeProvider type_provider; + TypeManager type_manager(type_factory, type_provider); + ValueFactory value_factory(type_manager); + google::protobuf::Value bool_value_proto; + bool_value_proto.set_bool_value(true); + auto value_proto = std::make_unique(); + value_proto->mutable_struct_value()->mutable_fields()->insert( + {"foo", bool_value_proto}); + ASSERT_OK_AND_ASSIGN( + auto value, ProtoValue::Create(value_factory, std::move(value_proto))); + EXPECT_TRUE(value->Is()); + EXPECT_EQ(value->As().size(), 1); + ASSERT_OK_AND_ASSIGN(auto key, value_factory.CreateStringValue("foo")); + EXPECT_THAT( + value->As().Get(MapValue::GetContext(value_factory), key), + IsOkAndHolds(Optional(ValueOf(value_factory, true)))); +} + +TEST_P(ProtoValueTest, StaticLValueValueStructValue) { + TypeFactory type_factory(memory_manager()); + ProtoTypeProvider type_provider; + TypeManager type_manager(type_factory, type_provider); + ValueFactory value_factory(type_manager); + google::protobuf::Value bool_value_proto; + bool_value_proto.set_bool_value(true); + google::protobuf::Value value_proto; + value_proto.mutable_struct_value()->mutable_fields()->insert( + {"foo", bool_value_proto}); + ASSERT_OK_AND_ASSIGN(auto value, + ProtoValue::Create(value_factory, value_proto)); + EXPECT_TRUE(value->Is()); + EXPECT_EQ(value->As().size(), 1); + ASSERT_OK_AND_ASSIGN(auto key, value_factory.CreateStringValue("foo")); + EXPECT_THAT( + value->As().Get(MapValue::GetContext(value_factory), key), + IsOkAndHolds(Optional(ValueOf(value_factory, true)))); +} + +TEST_P(ProtoValueTest, StaticRValueValueStructValue) { + TypeFactory type_factory(memory_manager()); + ProtoTypeProvider type_provider; + TypeManager type_manager(type_factory, type_provider); + ValueFactory value_factory(type_manager); + google::protobuf::Value bool_value_proto; + bool_value_proto.set_bool_value(true); + google::protobuf::Value value_proto; + value_proto.mutable_struct_value()->mutable_fields()->insert( + {"foo", bool_value_proto}); + ASSERT_OK_AND_ASSIGN( + auto value, ProtoValue::Create(value_factory, std::move(value_proto))); + EXPECT_TRUE(value->Is()); + EXPECT_EQ(value->As().size(), 1); + ASSERT_OK_AND_ASSIGN(auto key, value_factory.CreateStringValue("foo")); + EXPECT_THAT( + value->As().Get(MapValue::GetContext(value_factory), key), + IsOkAndHolds(Optional(ValueOf(value_factory, true)))); +} + +TEST_P(ProtoValueTest, StaticValueUnset) { + TypeFactory type_factory(memory_manager()); + ProtoTypeProvider type_provider; + TypeManager type_manager(type_factory, type_provider); + ValueFactory value_factory(type_manager); + auto value_proto = std::make_unique(); + EXPECT_THAT(ProtoValue::Create(value_factory, std::move(value_proto)), + IsOkAndHolds(ValueOf(value_factory))); +} + +TEST_P(ProtoValueTest, StaticLValueValueUnset) { + TypeFactory type_factory(memory_manager()); + ProtoTypeProvider type_provider; + TypeManager type_manager(type_factory, type_provider); + ValueFactory value_factory(type_manager); + google::protobuf::Value value_proto; + EXPECT_THAT(ProtoValue::Create(value_factory, value_proto), + IsOkAndHolds(ValueOf(value_factory))); +} + +TEST_P(ProtoValueTest, StaticRValueValueUnset) { + TypeFactory type_factory(memory_manager()); + ProtoTypeProvider type_provider; + TypeManager type_manager(type_factory, type_provider); + ValueFactory value_factory(type_manager); + google::protobuf::Value value_proto; + EXPECT_THAT(ProtoValue::Create(value_factory, std::move(value_proto)), + IsOkAndHolds(ValueOf(value_factory))); +} + +TEST_P(ProtoValueTest, StaticListValue) { + TypeFactory type_factory(memory_manager()); + ProtoTypeProvider type_provider; + TypeManager type_manager(type_factory, type_provider); + ValueFactory value_factory(type_manager); + auto list_value_proto = std::make_unique(); + list_value_proto->add_values()->set_bool_value(true); + ASSERT_OK_AND_ASSIGN( + auto value, + ProtoValue::Create(value_factory, std::move(list_value_proto))); + EXPECT_EQ(value->size(), 1); + EXPECT_THAT(value->Get(ListValue::GetContext(value_factory), 0), + IsOkAndHolds(ValueOf(value_factory, true))); +} + +TEST_P(ProtoValueTest, StaticLValueListValue) { + TypeFactory type_factory(memory_manager()); + ProtoTypeProvider type_provider; + TypeManager type_manager(type_factory, type_provider); + ValueFactory value_factory(type_manager); + google::protobuf::ListValue list_value_proto; + list_value_proto.add_values()->set_bool_value(true); + ASSERT_OK_AND_ASSIGN(auto value, + ProtoValue::Create(value_factory, list_value_proto)); + EXPECT_EQ(value->size(), 1); + EXPECT_THAT(value->Get(ListValue::GetContext(value_factory), 0), + IsOkAndHolds(ValueOf(value_factory, true))); +} + +TEST_P(ProtoValueTest, StaticRValueListValue) { + TypeFactory type_factory(memory_manager()); + ProtoTypeProvider type_provider; + TypeManager type_manager(type_factory, type_provider); + ValueFactory value_factory(type_manager); + google::protobuf::ListValue list_value_proto; + list_value_proto.add_values()->set_bool_value(true); + ASSERT_OK_AND_ASSIGN( + auto value, + ProtoValue::Create(value_factory, std::move(list_value_proto))); + EXPECT_EQ(value->size(), 1); + EXPECT_THAT(value->Get(ListValue::GetContext(value_factory), 0), + IsOkAndHolds(ValueOf(value_factory, true))); +} + +TEST_P(ProtoValueTest, StaticStruct) { + TypeFactory type_factory(memory_manager()); + ProtoTypeProvider type_provider; + TypeManager type_manager(type_factory, type_provider); + ValueFactory value_factory(type_manager); + google::protobuf::Value bool_value_proto; + bool_value_proto.set_bool_value(true); + auto struct_proto = std::make_unique(); + struct_proto->mutable_fields()->insert({"foo", bool_value_proto}); + ASSERT_OK_AND_ASSIGN( + auto value, ProtoValue::Create(value_factory, std::move(struct_proto))); + EXPECT_EQ(value->size(), 1); + ASSERT_OK_AND_ASSIGN(auto key, value_factory.CreateStringValue("foo")); + EXPECT_THAT(value->Get(MapValue::GetContext(value_factory), key), + IsOkAndHolds(Optional(ValueOf(value_factory, true)))); +} + +TEST_P(ProtoValueTest, StaticLValueStruct) { + TypeFactory type_factory(memory_manager()); + ProtoTypeProvider type_provider; + TypeManager type_manager(type_factory, type_provider); + ValueFactory value_factory(type_manager); + google::protobuf::Value bool_value_proto; + bool_value_proto.set_bool_value(true); + google::protobuf::Struct struct_proto; + struct_proto.mutable_fields()->insert({"foo", bool_value_proto}); + ASSERT_OK_AND_ASSIGN(auto value, + ProtoValue::Create(value_factory, struct_proto)); + EXPECT_EQ(value->size(), 1); + ASSERT_OK_AND_ASSIGN(auto key, value_factory.CreateStringValue("foo")); + EXPECT_THAT(value->Get(MapValue::GetContext(value_factory), key), + IsOkAndHolds(Optional(ValueOf(value_factory, true)))); +} + +TEST_P(ProtoValueTest, StaticRValueStruct) { + TypeFactory type_factory(memory_manager()); + ProtoTypeProvider type_provider; + TypeManager type_manager(type_factory, type_provider); + ValueFactory value_factory(type_manager); + google::protobuf::Value bool_value_proto; + bool_value_proto.set_bool_value(true); + google::protobuf::Struct struct_proto; + struct_proto.mutable_fields()->insert({"foo", bool_value_proto}); + ASSERT_OK_AND_ASSIGN( + auto value, ProtoValue::Create(value_factory, std::move(struct_proto))); + EXPECT_EQ(value->size(), 1); + ASSERT_OK_AND_ASSIGN(auto key, value_factory.CreateStringValue("foo")); + EXPECT_THAT(value->Get(MapValue::GetContext(value_factory), key), + IsOkAndHolds(Optional(ValueOf(value_factory, true)))); +} + TEST_P(ProtoValueTest, StaticWrapperTypes) { TypeFactory type_factory(memory_manager()); ProtoTypeProvider type_provider; From 95a517d769af2256879451ebe6bc1d11125fce10 Mon Sep 17 00:00:00 2001 From: jcking Date: Thu, 4 May 2023 17:17:49 +0000 Subject: [PATCH 241/303] Cleanup `MemoryManager` and drop `ManagedMemory` PiperOrigin-RevId: 529444909 --- base/BUILD | 31 +- base/allocator.h | 114 +----- base/handle.h | 23 +- base/internal/BUILD | 12 - base/internal/managed_memory.cc | 77 ---- base/internal/managed_memory.h | 372 ----------------- base/managed_memory.h | 36 -- base/memory.h | 377 ++++++++++++++++++ base/memory_manager.h | 133 +----- base/memory_manager_test.cc | 34 +- eval/eval/BUILD | 4 + eval/eval/attribute_utility.cc | 25 +- eval/eval/attribute_utility.h | 13 +- eval/eval/create_list_step.cc | 20 +- eval/eval/create_struct_step.cc | 21 +- eval/eval/function_step.cc | 8 +- eval/internal/interop_test.cc | 23 +- eval/public/BUILD | 3 +- .../portable_cel_expr_builder_factory_test.cc | 11 +- .../structs/proto_message_type_adapter.cc | 12 +- extensions/protobuf/memory_manager_test.cc | 76 +--- 21 files changed, 493 insertions(+), 932 deletions(-) delete mode 100644 base/internal/managed_memory.cc delete mode 100644 base/internal/managed_memory.h delete mode 100644 base/managed_memory.h create mode 100644 base/memory.h diff --git a/base/BUILD b/base/BUILD index eacde47b1..98612af41 100644 --- a/base/BUILD +++ b/base/BUILD @@ -19,13 +19,9 @@ package( licenses(["notice"]) -cc_library( +alias( name = "allocator", - hdrs = ["allocator.h"], - deps = [ - ":memory_manager", - "@com_google_absl//absl/base:core_headers", - ], + actual = ":memory", ) cc_test( @@ -63,7 +59,6 @@ cc_library( name = "handle", hdrs = ["handle.h"], deps = [ - ":memory_manager", "//base/internal:data", "//base/internal:handle", "@com_google_absl//absl/base:core_headers", @@ -100,17 +95,15 @@ cc_test( ) cc_library( - name = "managed_memory", - hdrs = ["managed_memory.h"], - deps = ["//base/internal:managed_memory"], -) - -cc_library( - name = "memory_manager", + name = "memory", srcs = ["memory_manager.cc"], - hdrs = ["memory_manager.h"], + hdrs = [ + "allocator.h", + "memory.h", + "memory_manager.h", + ], deps = [ - ":managed_memory", + ":handle", "//base/internal:data", "//base/internal:memory_manager", "//internal:no_destructor", @@ -118,11 +111,17 @@ cc_library( "@com_google_absl//absl/base:config", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:dynamic_annotations", + "@com_google_absl//absl/log:die_if_null", "@com_google_absl//absl/numeric:bits", "@com_google_absl//absl/synchronization", ], ) +alias( + name = "memory_manager", + actual = ":memory", +) + cc_test( name = "memory_manager_test", srcs = ["memory_manager_test.cc"], diff --git a/base/allocator.h b/base/allocator.h index 9a6069515..e75217f9c 100644 --- a/base/allocator.h +++ b/base/allocator.h @@ -15,118 +15,8 @@ #ifndef THIRD_PARTY_CEL_CPP_BASE_ALLOCATOR_H_ #define THIRD_PARTY_CEL_CPP_BASE_ALLOCATOR_H_ -#include -#include -#include -#include -#include -#include +// TODO(issues/5): delete -#include "absl/base/attributes.h" -#include "base/memory_manager.h" - -namespace cel { - -// STL allocator implementation which is backed by MemoryManager. -template -class Allocator final { - public: - using value_type = T; - using pointer = T*; - using const_pointer = const T*; - using reference = T&; - using const_reference = const T&; - using size_type = size_t; - using difference_type = ptrdiff_t; - using propagate_on_container_move_assignment = std::true_type; - using is_always_equal = std::false_type; - - template - struct rebind final { - using other = Allocator; - }; - - explicit Allocator( - ABSL_ATTRIBUTE_LIFETIME_BOUND MemoryManager& memory_manager) - : memory_manager_(memory_manager), - allocation_only_(memory_manager.allocation_only_) {} - - Allocator(const Allocator&) = default; - - template - Allocator(const Allocator& other) // NOLINT(google-explicit-constructor) - : memory_manager_(other.memory_manager_), - allocation_only_(other.allocation_only_) {} - - pointer allocate(size_type n) { - if (!memory_manager_.allocation_only_) { - return static_cast(::operator new( - n * sizeof(T), static_cast(alignof(T)))); - } - return static_cast(memory_manager_.Allocate(n * sizeof(T), alignof(T))); - } - - pointer allocate(size_type n, const void* hint) { - static_cast(hint); - return allocate(n); - } - - void deallocate(pointer p, size_type n) { - if (!allocation_only_) { - ::operator delete(static_cast(p), n * sizeof(T), - static_cast(alignof(T))); - } - } - - constexpr size_type max_size() const noexcept { - return std::numeric_limits::max() / sizeof(value_type); - } - - pointer address(reference x) const noexcept { return std::addressof(x); } - - const_pointer address(const_reference x) const noexcept { - return std::addressof(x); - } - - void construct(pointer p, const_reference val) { - ::new (static_cast(p)) T(val); - } - - template - void construct(U* p, Args&&... args) { - ::new (static_cast(p)) U(std::forward(args)...); - } - - void destroy(pointer p) { p->~T(); } - - template - void destroy(U* p) { - p->~U(); - } - - template - bool operator==(const Allocator& rhs) const { - return &memory_manager_ == &rhs.memory_manager_; - } - - template - bool operator!=(const Allocator& rhs) const { - return &memory_manager_ != &rhs.memory_manager_; - } - - private: - template - friend class Allocator; - - MemoryManager& memory_manager_; - // Ugh. This is here because of legacy behavior. MemoryManager& is guaranteed - // to exist during allocation, but not necessarily during deallocation. So we - // store the member variable from MemoryManager. This can go away once - // CelValue and friends are entirely gone and everybody is instantiating their - // own MemoryManager. - bool allocation_only_; -}; - -} // namespace cel +#include "base/memory.h" #endif // THIRD_PARTY_CEL_CPP_BASE_ALLOCATOR_H_ diff --git a/base/handle.h b/base/handle.h index 2d14e8353..90b860917 100644 --- a/base/handle.h +++ b/base/handle.h @@ -19,16 +19,15 @@ #include #include "absl/base/attributes.h" -#include "absl/base/macros.h" -#include "absl/base/optimization.h" #include "absl/log/absl_check.h" #include "absl/utility/utility.h" #include "base/internal/data.h" #include "base/internal/handle.h" // IWYU pragma: export -#include "base/memory_manager.h" namespace cel { +class MemoryManager; + // `Handle` is a handle that shares ownership of the referenced `T`. It is valid // so long as there are 1 or more handles pointing to `T` and the // `AllocationManager` that constructed it is alive. @@ -251,6 +250,7 @@ class Handle final : private base_internal::HandlePolicy { friend class Handle; template friend struct base_internal::HandleFactory; + friend class MemoryManager; template explicit Handle(absl::in_place_t, Args&&... args) @@ -294,22 +294,7 @@ struct HandleFactory { // implementation. template static std::enable_if_t, Handle> Make( - MemoryManager& memory_manager, Args&&... args) { - static_assert(std::is_base_of_v, "T is not derived from Data"); - static_assert(std::is_base_of_v, "F is not derived from T"); -#if defined(__cpp_lib_is_pointer_interconvertible) && \ - __cpp_lib_is_pointer_interconvertible >= 201907L - // Only available in C++20. - static_assert(std::is_pointer_interconvertible_base_of_v, - "F must be pointer interconvertible to Data"); -#endif - auto managed_memory = memory_manager.New(std::forward(args)...); - if (ABSL_PREDICT_FALSE(managed_memory == nullptr)) { - return Handle(); - } - return Handle(absl::in_place, - *base_internal::ManagedMemoryRelease(managed_memory)); - } + MemoryManager& memory_manager, Args&&... args); }; } // namespace cel::base_internal diff --git a/base/internal/BUILD b/base/internal/BUILD index cc05c632f..00e01b1dd 100644 --- a/base/internal/BUILD +++ b/base/internal/BUILD @@ -45,18 +45,6 @@ cc_library( ], ) -cc_library( - name = "managed_memory", - srcs = ["managed_memory.cc"], - hdrs = ["managed_memory.h"], - deps = [ - ":data", - "@com_google_absl//absl/base:config", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/numeric:bits", - ], -) - cc_library( name = "memory_manager", hdrs = [ diff --git a/base/internal/managed_memory.cc b/base/internal/managed_memory.cc deleted file mode 100644 index ce13dc587..000000000 --- a/base/internal/managed_memory.cc +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright 2022 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "base/internal/managed_memory.h" - -#include -#include -#include - -#include "absl/base/config.h" -#include "absl/base/macros.h" -#include "absl/base/optimization.h" -#include "absl/numeric/bits.h" - -namespace cel::base_internal { - -namespace { - -size_t AlignUp(size_t size, size_t align) { -#if ABSL_HAVE_BUILTIN(__builtin_align_up) - return __builtin_align_up(size, align); -#else - return (size + align - size_t{1}) & ~(align - size_t{1}); -#endif -} - -} // namespace - -std::pair ManagedMemoryState::New( - size_t size, size_t align, ManagedMemoryDestructor destructor) { - ABSL_ASSERT(size != 0); - ABSL_ASSERT(absl::has_single_bit(align)); // Assert aligned to power of 2. - if (ABSL_PREDICT_TRUE(align <= sizeof(ManagedMemoryState))) { - // Alignment requirements are less than the size of `ManagedMemoryState`, we - // can place `ManagedMemoryState` in front. - uint8_t* pointer = reinterpret_cast( - ::operator new(size + sizeof(ManagedMemoryState))); - ::new (pointer) ManagedMemoryState(destructor); - return {reinterpret_cast(pointer), - static_cast(pointer + sizeof(ManagedMemoryState))}; - } - // Alignment requirements are greater than the size of `ManagedMemoryState`, - // we need to place `ManagedMemoryState` at the back and pad to ensure - // `ManagedMemoryState` itself is aligned. - size_t adjusted_size = AlignUp(size, alignof(ManagedMemoryState)); - uint8_t* pointer = reinterpret_cast( - ::operator new(adjusted_size + sizeof(ManagedMemoryState))); - ::new (pointer + adjusted_size) ManagedMemoryState(destructor); - return {reinterpret_cast(pointer + adjusted_size), - static_cast(pointer)}; -} - -void ManagedMemoryState::Delete(void* pointer) { - ABSL_ASSERT(pointer != nullptr); - ABSL_ASSERT(this != pointer); - if (destructor_ != nullptr) { - (*destructor_)(pointer); - } - this->~ManagedMemoryState(); - ::operator delete(reinterpret_cast(this) < - static_cast(pointer) - ? static_cast(this) - : const_cast(pointer)); -} - -} // namespace cel::base_internal diff --git a/base/internal/managed_memory.h b/base/internal/managed_memory.h deleted file mode 100644 index dd7211b76..000000000 --- a/base/internal/managed_memory.h +++ /dev/null @@ -1,372 +0,0 @@ -// Copyright 2022 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef THIRD_PARTY_CEL_CPP_BASE_INTERNAL_MANAGED_MEMORY_H_ -#define THIRD_PARTY_CEL_CPP_BASE_INTERNAL_MANAGED_MEMORY_H_ - -#include -#include -#include -#include -#include - -#include "absl/base/attributes.h" -#include "absl/base/macros.h" -#include "absl/numeric/bits.h" -#include "base/internal/data.h" - -namespace cel { - -class MemoryManager; - -namespace base_internal { - -template > -class ManagedMemory; - -template -T* ManagedMemoryRelease(ManagedMemory& managed_memory); - -// ManagedMemory implementation for T that is derived from Data and HeapData. -template -class ManagedMemory final { - private: - static_assert(std::is_base_of_v, - "T must be derived from HeapData"); - - public: - ManagedMemory() = default; - - explicit ManagedMemory(std::nullptr_t) : ManagedMemory() {} - - ManagedMemory(const ManagedMemory& other) : pointer_(other.pointer_) { - Ref(); - } - - template >> - ManagedMemory(const ManagedMemory& other) // NOLINT - : pointer_(other.pointer_) { - Ref(); - } - - ManagedMemory(ManagedMemory&& other) : ManagedMemory() { - std::swap(pointer_, other.pointer_); - } - - template >> - ManagedMemory(ManagedMemory&& other) // NOLINT - : ManagedMemory() { - std::swap(pointer_, other.pointer_); - } - - ~ManagedMemory() { Unref(); } - - ManagedMemory& operator=(const ManagedMemory& other) { - if (this != &other) { - other.Ref(); - Unref(); - pointer_ = other.pointer_; - } - return *this; - } - - template - std::enable_if_t, ManagedMemory&> // NOLINT - operator=(const ManagedMemory& other) { - if (this != &other) { - other.Ref(); - Unref(); - pointer_ = other.pointer_; - } - return *this; - } - - ManagedMemory& operator=(ManagedMemory&& other) { - if (this != &other) { - Unref(); - pointer_ = 0; - std::swap(pointer_, other.pointer_); - } - return *this; - } - - template - std::enable_if_t, ManagedMemory&> // NOLINT - operator=(ManagedMemory&& other) { - if (this != &other) { - Unref(); - pointer_ = 0; - std::swap(pointer_, other.pointer_); - } - return *this; - } - - T* get() const ABSL_ATTRIBUTE_LIFETIME_BOUND { - return reinterpret_cast(pointer_ & kPointerMask); - } - - T& operator*() const ABSL_ATTRIBUTE_LIFETIME_BOUND { - ABSL_ASSERT(get() != nullptr); - return *get(); - } - - T* operator->() const ABSL_ATTRIBUTE_LIFETIME_BOUND { - ABSL_ASSERT(get() != nullptr); - return get(); - } - - explicit operator bool() const { return get() != nullptr; } - - ABSL_MUST_USE_RESULT T* release() { - if (pointer_ == 0) { - return nullptr; - } - ABSL_ASSERT((pointer_ & kPointerArenaAllocated) == kPointerArenaAllocated); - T* pointer = get(); - pointer_ = 0; - return pointer; - } - - private: - friend class cel::MemoryManager; - - template - friend F* ManagedMemoryRelease(ManagedMemory& managed_memory); - - explicit ManagedMemory(T* pointer) { - ABSL_ASSERT(absl::countr_zero(reinterpret_cast(pointer)) >= 2); - pointer_ = - reinterpret_cast(pointer) | - (Metadata::IsArenaAllocated(*pointer) ? kPointerArenaAllocated : 0); - } - - void Ref() const { - if (pointer_ != 0 && (pointer_ & kPointerArenaAllocated) == 0) { - Metadata::Ref(**this); - } - } - - void Unref() const { - if (pointer_ != 0 && (pointer_ & kPointerArenaAllocated) == 0 && - Metadata::Unref(**this)) { - delete static_cast(get()); - } - } - - uintptr_t pointer_ = 0; -}; - -template -T* ManagedMemoryRelease(ManagedMemory& managed_memory) { - T* pointer = managed_memory.get(); - managed_memory.pointer_ = 0; - return pointer; -} - -using ManagedMemoryDestructor = void (*)(void*); - -// Shared state used by `ManagedMemory` that holds the reference count -// and destructor to call when the reference count hits 0. `MemoryManager` -// places `T` and `ManagedMemoryState` in the same allocation. Whether -// `ManagedMemoryState` is before or after `T` depends on alignment requirements -// of `T`. -class ManagedMemoryState final { - public: - static std::pair New( - size_t size, size_t align, ManagedMemoryDestructor destructor); - - ManagedMemoryState() = delete; - ManagedMemoryState(const ManagedMemoryState&) = delete; - ManagedMemoryState(ManagedMemoryState&&) = delete; - ManagedMemoryState& operator=(const ManagedMemoryState&) = delete; - ManagedMemoryState& operator=(ManagedMemoryState&&) = delete; - - void Ref() { - const auto reference_count = - reference_count_.fetch_add(1, std::memory_order_relaxed); - ABSL_ASSERT(reference_count > 0); - } - - ABSL_MUST_USE_RESULT bool Unref() { - const auto reference_count = - reference_count_.fetch_sub(1, std::memory_order_seq_cst); - ABSL_ASSERT(reference_count > 0); - return reference_count == 1; - } - - void Delete(void* pointer); - - private: - explicit ManagedMemoryState(ManagedMemoryDestructor destructor) - : reference_count_(1), destructor_(destructor) {} - - mutable std::atomic reference_count_; - ManagedMemoryDestructor destructor_; -}; - -// ManagedMemory implementation for T that is not derived from Data. This is -// very similar to `std::shared_ptr`. -template -class ManagedMemory final { - public: - ManagedMemory() = default; - - explicit ManagedMemory(std::nullptr_t) : ManagedMemory() {} - - ManagedMemory(const ManagedMemory& other) - : pointer_(other.pointer_), state_(other.state_) { - Ref(); - } - - template >> - ManagedMemory(const ManagedMemory& other) // NOLINT - : pointer_(static_cast(other.pointer_)), state_(other.state_) { - Ref(); - } - - ManagedMemory(ManagedMemory&& other) : ManagedMemory() { - std::swap(pointer_, other.pointer_); - std::swap(state_, other.state_); - } - - template >> - ManagedMemory(ManagedMemory&& other) // NOLINT - : pointer_(static_cast(other.pointer_)), state_(other.state_) { - other.pointer_ = nullptr; - other.state_ = nullptr; - } - - ~ManagedMemory() { Unref(); } - - ManagedMemory& operator=(const ManagedMemory& other) { - if (this != &other) { - other.Ref(); - Unref(); - pointer_ = other.pointer_; - state_ = other.state_; - } - return *this; - } - - template - std::enable_if_t, ManagedMemory&> // NOLINT - operator=(const ManagedMemory& other) { - if (this != &other) { - other.Ref(); - Unref(); - pointer_ = static_cast(other.pointer_); - state_ = other.state_; - } - return *this; - } - - ManagedMemory& operator=(ManagedMemory&& other) { - if (this != &other) { - Unref(); - pointer_ = nullptr; - state_ = nullptr; - std::swap(pointer_, other.pointer_); - std::swap(state_, other.state_); - } - return *this; - } - - template - std::enable_if_t, ManagedMemory&> // NOLINT - operator=(ManagedMemory&& other) { - if (this != &other) { - Unref(); - pointer_ = static_cast(other.pointer_); - state_ = other.state_; - other.pointer_ = nullptr; - other.state_ = nullptr; - } - return *this; - } - - T* get() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return pointer_; } - - T& operator*() const ABSL_ATTRIBUTE_LIFETIME_BOUND { - ABSL_ASSERT(get() != nullptr); - return *get(); - } - - T* operator->() const ABSL_ATTRIBUTE_LIFETIME_BOUND { - ABSL_ASSERT(get() != nullptr); - return get(); - } - - explicit operator bool() const { return get() != nullptr; } - - ABSL_MUST_USE_RESULT T* release() { - if (pointer_ == nullptr) { - return nullptr; - } - ABSL_ASSERT(state_ == nullptr); - T* pointer = pointer_; - pointer_ = nullptr; - return pointer; - } - - private: - friend class cel::MemoryManager; - - ManagedMemory(T* pointer, ManagedMemoryState* state) - : pointer_(pointer), state_(state) {} - - void Ref() const { - if (state_ != nullptr) { - state_->Ref(); - } - } - - void Unref() const { - if (state_ != nullptr && state_->Unref()) { - state_->Delete(const_cast(static_cast(get()))); - } - } - - T* pointer_ = nullptr; - ManagedMemoryState* state_ = nullptr; -}; - -template -constexpr bool operator==(const ManagedMemory& lhs, std::nullptr_t) { - return !static_cast(lhs); -} - -template -constexpr bool operator==(std::nullptr_t, const ManagedMemory& rhs) { - return !static_cast(rhs); -} - -template -constexpr bool operator!=(const ManagedMemory& lhs, std::nullptr_t) { - return !operator==(lhs, nullptr); -} - -template -constexpr bool operator!=(std::nullptr_t, const ManagedMemory& rhs) { - return !operator==(nullptr, rhs); -} - -} // namespace base_internal - -} // namespace cel - -#endif // THIRD_PARTY_CEL_CPP_BASE_INTERNAL_MANAGED_MEMORY_H_ diff --git a/base/managed_memory.h b/base/managed_memory.h deleted file mode 100644 index ae8628a1b..000000000 --- a/base/managed_memory.h +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright 2022 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef THIRD_PARTY_CEL_CPP_BASE_MANAGED_MEMORY_H_ -#define THIRD_PARTY_CEL_CPP_BASE_MANAGED_MEMORY_H_ - -#include - -#include "base/internal/managed_memory.h" - -namespace cel { - -// `ManagedMemory` is a smart pointer which ensures any applicable object -// destructors and deallocation are eventually performed. Copying does not -// actually copy the underlying T, instead a pointer is copied and optionally -// reference counted. Moving does not actually move the underlying T, instead a -// pointer is moved. -// -// TODO(issues/5): consider feature parity with std::unique_ptr -template -using ManagedMemory = base_internal::ManagedMemory; - -} // namespace cel - -#endif // THIRD_PARTY_CEL_CPP_BASE_MANAGED_MEMORY_H_ diff --git a/base/memory.h b/base/memory.h new file mode 100644 index 000000000..b564429db --- /dev/null +++ b/base/memory.h @@ -0,0 +1,377 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_MEMORY_H_ +#define THIRD_PARTY_CEL_CPP_BASE_MEMORY_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/macros.h" +#include "absl/base/optimization.h" +#include "absl/log/die_if_null.h" +#include "base/handle.h" +#include "base/internal/data.h" +#include "base/internal/memory_manager.h" +#include "internal/rtti.h" + +namespace cel { + +template +class Allocator; +class MemoryManager; +class GlobalMemoryManager; +class ArenaMemoryManager; + +namespace extensions { +class ProtoMemoryManager; +} + +// `UniqueRef` is similar to `std::unique_ptr`, but works with `MemoryManager` +// and is more strict to help prevent misuse. It is undefined behavior to access +// `UniqueRef` after being moved. +template +class UniqueRef final { + public: + UniqueRef() = delete; + + UniqueRef(const UniqueRef&) = delete; + + UniqueRef(UniqueRef&& other) noexcept + : ref_(other.ref_), owned_(other.owned_) { + other.ref_ = nullptr; + other.owned_ = false; + } + + template >> + UniqueRef(UniqueRef&& other) noexcept // NOLINT + : ref_(other.ref_), owned_(other.owned_) { + other.ref_ = nullptr; + other.owned_ = false; + } + + ~UniqueRef() { + if (ref_ != nullptr) { + if (owned_) { + delete ref_; + } else { + ref_->~T(); + } + } + } + + UniqueRef& operator=(const UniqueRef&) = delete; + + UniqueRef& operator=(UniqueRef&& other) noexcept { + if (ABSL_PREDICT_TRUE(this != &other)) { + if (ref_ != nullptr) { + if (owned_) { + delete ref_; + } else { + ref_->~T(); + } + } + ref_ = other.ref_; + owned_ = other.owned_; + other.ref_ = nullptr; + other.owned_ = false; + } + return *this; + } + + template >> + UniqueRef& operator=(UniqueRef&& other) noexcept { // NOLINT + if (ABSL_PREDICT_TRUE(this != &other)) { + if (ref_ != nullptr) { + if (owned_) { + delete ref_; + } else { + ref_->~T(); + } + } + ref_ = other.ref_; + owned_ = other.owned_; + other.ref_ = nullptr; + other.owned_ = false; + } + return *this; + } + + T* operator->() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_ASSERT(ref_ != nullptr); + ABSL_ASSUME(ref_ != nullptr); + return ref_; + } + + T& operator*() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return get(); } + + T& get() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_ASSERT(ref_ != nullptr); + ABSL_ASSUME(ref_ != nullptr); + return *ref_; + } + + private: + template + friend class UniqueRef; + friend class MemoryManager; + + UniqueRef(T* ref, bool owned) ABSL_ATTRIBUTE_NONNULL() + : ref_(ABSL_DIE_IF_NULL(ref)), // Crash OK + owned_(owned) {} + + T* ref_; + bool owned_; +}; + +template +ABSL_MUST_USE_RESULT UniqueRef MakeUnique(MemoryManager& memory_manager + ABSL_ATTRIBUTE_LIFETIME_BOUND, + Args&&... args); + +// `MemoryManager` is an abstraction over memory management that supports +// different allocation strategies. +class MemoryManager { + public: + ABSL_ATTRIBUTE_PURE_FUNCTION static MemoryManager& Global(); + + virtual ~MemoryManager() = default; + + private: + friend class GlobalMemoryManager; + friend class ArenaMemoryManager; + friend class extensions::ProtoMemoryManager; + template + friend class Allocator; + template + friend struct base_internal::HandleFactory; + template + friend UniqueRef MakeUnique(MemoryManager&, Args&&...); + + // Only for use by GlobalMemoryManager and ArenaMemoryManager. + explicit MemoryManager(bool allocation_only) + : allocation_only_(allocation_only) {} + + // Allocates and constructs `T`. + template + Handle AllocateHandle(Args&&... args) + ABSL_ATTRIBUTE_LIFETIME_BOUND ABSL_MUST_USE_RESULT { + static_assert(std::is_base_of_v); + T* pointer; + if (!allocation_only_) { + pointer = new T(std::forward(args)...); + base_internal::Metadata::SetReferenceCounted(*pointer); + } else { + pointer = ::new (Allocate(sizeof(T), alignof(T))) + T(std::forward(args)...); + if constexpr (!std::is_trivially_destructible_v) { + if constexpr (base_internal::HasIsDestructorSkippable::value) { + if (!T::IsDestructorSkippable(*pointer)) { + OwnDestructor(pointer, + &base_internal::MemoryManagerDestructor::Destruct); + } + } else { + OwnDestructor(pointer, + &base_internal::MemoryManagerDestructor::Destruct); + } + } + base_internal::Metadata::SetArenaAllocated(*pointer); + } + return Handle(absl::in_place, *pointer); + } + + template + UniqueRef AllocateUnique(Args&&... args) + ABSL_ATTRIBUTE_LIFETIME_BOUND ABSL_MUST_USE_RESULT { + static_assert(!std::is_base_of_v); + T* ptr; + if (allocation_only_) { + ptr = ::new (Allocate(sizeof(T), alignof(T))) + T(std::forward(args)...); + } else { + ptr = new T(std::forward(args)...); + } + return UniqueRef(ptr, !allocation_only_); + } + + // These are virtual private, ensuring only `MemoryManager` calls these. + + // Allocates memory of at least size `size` in bytes that is at least as + // aligned as `align`. + virtual void* Allocate(size_t size, size_t align) = 0; + + // Registers a destructor to be run upon destruction of the memory management + // implementation. + virtual void OwnDestructor(void* pointer, void (*destruct)(void*)) = 0; + + virtual internal::TypeInfo TypeId() const { return internal::TypeInfo(); } + + const bool allocation_only_; +}; + +// Allocates and constructs `T`. +template +UniqueRef MakeUnique(MemoryManager& memory_manager, Args&&... args) { + return memory_manager.AllocateUnique(std::forward(args)...); +} + +// Base class for all arena-based memory managers. +class ArenaMemoryManager : public MemoryManager { + public: + // Returns the default implementation of an arena-based memory manager. In + // most cases it should be good enough, however you should not rely on its + // performance characteristics. + static std::unique_ptr Default(); + + protected: + ArenaMemoryManager() : ArenaMemoryManager(true) {} + + private: + friend class extensions::ProtoMemoryManager; + + // Private so that only ProtoMemoryManager can use it for legacy reasons. All + // other derivations of ArenaMemoryManager should be allocation-only. + explicit ArenaMemoryManager(bool allocation_only) + : MemoryManager(allocation_only) {} +}; + +// STL allocator implementation which is backed by MemoryManager. +template +class Allocator final { + public: + using value_type = T; + using pointer = T*; + using const_pointer = const T*; + using reference = T&; + using const_reference = const T&; + using size_type = size_t; + using difference_type = ptrdiff_t; + using propagate_on_container_move_assignment = std::true_type; + using is_always_equal = std::false_type; + + template + struct rebind final { + using other = Allocator; + }; + + explicit Allocator( + ABSL_ATTRIBUTE_LIFETIME_BOUND MemoryManager& memory_manager) + : memory_manager_(memory_manager), + allocation_only_(memory_manager.allocation_only_) {} + + Allocator(const Allocator&) = default; + + template + Allocator(const Allocator& other) // NOLINT(google-explicit-constructor) + : memory_manager_(other.memory_manager_), + allocation_only_(other.allocation_only_) {} + + pointer allocate(size_type n) { + if (!memory_manager_.allocation_only_) { + return static_cast(::operator new( + n * sizeof(T), static_cast(alignof(T)))); + } + return static_cast(memory_manager_.Allocate(n * sizeof(T), alignof(T))); + } + + pointer allocate(size_type n, const void* hint) { + static_cast(hint); + return allocate(n); + } + + void deallocate(pointer p, size_type n) { + if (!allocation_only_) { + ::operator delete(static_cast(p), n * sizeof(T), + static_cast(alignof(T))); + } + } + + constexpr size_type max_size() const noexcept { + return std::numeric_limits::max() / sizeof(value_type); + } + + pointer address(reference x) const noexcept { return std::addressof(x); } + + const_pointer address(const_reference x) const noexcept { + return std::addressof(x); + } + + void construct(pointer p, const_reference val) { + ::new (static_cast(p)) T(val); + } + + template + void construct(U* p, Args&&... args) { + ::new (static_cast(p)) U(std::forward(args)...); + } + + void destroy(pointer p) { p->~T(); } + + template + void destroy(U* p) { + p->~U(); + } + + template + bool operator==(const Allocator& rhs) const { + return &memory_manager_ == &rhs.memory_manager_; + } + + template + bool operator!=(const Allocator& rhs) const { + return &memory_manager_ != &rhs.memory_manager_; + } + + private: + template + friend class Allocator; + + MemoryManager& memory_manager_; + // Ugh. This is here because of legacy behavior. MemoryManager& is guaranteed + // to exist during allocation, but not necessarily during deallocation. So we + // store the member variable from MemoryManager. This can go away once + // CelValue and friends are entirely gone and everybody is instantiating their + // own MemoryManager. + bool allocation_only_; +}; + +namespace base_internal { + +template +template +std::enable_if_t, Handle> +HandleFactory::Make(MemoryManager& memory_manager, Args&&... args) { + static_assert(std::is_base_of_v, "T is not derived from Data"); + static_assert(std::is_base_of_v, "F is not derived from T"); +#if defined(__cpp_lib_is_pointer_interconvertible) && \ + __cpp_lib_is_pointer_interconvertible >= 201907L + // Only available in C++20. + static_assert(std::is_pointer_interconvertible_base_of_v, + "F must be pointer interconvertible to Data"); +#endif + return memory_manager.AllocateHandle(std::forward(args)...); +} + +} // namespace base_internal + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_MEMORY_H_ diff --git a/base/memory_manager.h b/base/memory_manager.h index c296febdf..f73e9f84b 100644 --- a/base/memory_manager.h +++ b/base/memory_manager.h @@ -15,137 +15,8 @@ #ifndef THIRD_PARTY_CEL_CPP_BASE_MEMORY_MANAGER_H_ #define THIRD_PARTY_CEL_CPP_BASE_MEMORY_MANAGER_H_ -#include -#include -#include -#include +// TODO(issues/5): delete -#include "absl/base/attributes.h" -#include "base/internal/data.h" -#include "base/internal/memory_manager.h" -#include "base/managed_memory.h" -#include "internal/rtti.h" - -namespace cel { - -template -class Allocator; -class MemoryManager; -class GlobalMemoryManager; -class ArenaMemoryManager; - -namespace extensions { -class ProtoMemoryManager; -} - -// `MemoryManager` is an abstraction over memory management that supports -// different allocation strategies. -class MemoryManager { - public: - ABSL_ATTRIBUTE_PURE_FUNCTION static MemoryManager& Global(); - - virtual ~MemoryManager() = default; - - // Allocates and constructs `T`. In the event of an allocation failure nullptr - // is returned. - template - std::enable_if_t, ManagedMemory> - New(Args&&... args) ABSL_ATTRIBUTE_LIFETIME_BOUND ABSL_MUST_USE_RESULT { - static_assert(std::is_base_of_v, - "T must only be stored inline"); - if (!allocation_only_) { - T* pointer = new T(std::forward(args)...); - base_internal::Metadata::SetReferenceCounted(*pointer); - return ManagedMemory(pointer); - } - void* pointer = Allocate(sizeof(T), alignof(T)); - ::new (pointer) T(std::forward(args)...); - if constexpr (!std::is_trivially_destructible_v) { - if constexpr (base_internal::HasIsDestructorSkippable::value) { - if (!T::IsDestructorSkippable(*static_cast(pointer))) { - OwnDestructor(pointer, - &base_internal::MemoryManagerDestructor::Destruct); - } - } else { - OwnDestructor(pointer, - &base_internal::MemoryManagerDestructor::Destruct); - } - } - base_internal::Metadata::SetArenaAllocated(*static_cast(pointer)); - return ManagedMemory(static_cast(pointer)); - } - - // Allocates and constructs `T`. In the event of an allocation failure nullptr - // is returned. - template - std::enable_if_t>, - ManagedMemory> - New(Args&&... args) ABSL_ATTRIBUTE_LIFETIME_BOUND ABSL_MUST_USE_RESULT { - if (!allocation_only_) { - base_internal::ManagedMemoryDestructor destructor = nullptr; - if constexpr (!std::is_trivially_destructible_v) { - destructor = &base_internal::MemoryManagerDestructor::Destruct; - } - auto [state, pointer] = base_internal::ManagedMemoryState::New( - sizeof(T), alignof(T), destructor); - ::new (pointer) T(std::forward(args)...); - return ManagedMemory(reinterpret_cast(pointer), state); - } - void* pointer = Allocate(sizeof(T), alignof(T)); - ::new (pointer) T(std::forward(args)...); - if constexpr (!std::is_trivially_destructible_v) { - OwnDestructor(pointer, - &base_internal::MemoryManagerDestructor::Destruct); - } - return ManagedMemory(reinterpret_cast(pointer), nullptr); - } - - private: - friend class GlobalMemoryManager; - friend class ArenaMemoryManager; - friend class extensions::ProtoMemoryManager; - template - friend class Allocator; - - // Only for use by GlobalMemoryManager and ArenaMemoryManager. - explicit MemoryManager(bool allocation_only) - : allocation_only_(allocation_only) {} - - // These are virtual private, ensuring only `MemoryManager` calls these. - - // Allocates memory of at least size `size` in bytes that is at least as - // aligned as `align`. - virtual void* Allocate(size_t size, size_t align) = 0; - - // Registers a destructor to be run upon destruction of the memory management - // implementation. - virtual void OwnDestructor(void* pointer, void (*destruct)(void*)) = 0; - - virtual internal::TypeInfo TypeId() const { return internal::TypeInfo(); } - - const bool allocation_only_; -}; - -// Base class for all arena-based memory managers. -class ArenaMemoryManager : public MemoryManager { - public: - // Returns the default implementation of an arena-based memory manager. In - // most cases it should be good enough, however you should not rely on its - // performance characteristics. - static std::unique_ptr Default(); - - protected: - ArenaMemoryManager() : ArenaMemoryManager(true) {} - - private: - friend class extensions::ProtoMemoryManager; - - // Private so that only ProtoMemoryManager can use it for legacy reasons. All - // other derivations of ArenaMemoryManager should be allocation-only. - explicit ArenaMemoryManager(bool allocation_only) - : MemoryManager(allocation_only) {} -}; - -} // namespace cel +#include "base/memory.h" #endif // THIRD_PARTY_CEL_CPP_BASE_MEMORY_MANAGER_H_ diff --git a/base/memory_manager_test.cc b/base/memory_manager_test.cc index fe20fb02b..c31cca027 100644 --- a/base/memory_manager_test.cc +++ b/base/memory_manager_test.cc @@ -21,15 +21,6 @@ namespace cel { namespace { -struct TriviallyDestructible final {}; - -TEST(GlobalMemoryManager, TriviallyDestructible) { - EXPECT_TRUE(std::is_trivially_destructible_v); - auto managed = MemoryManager::Global().New(); - EXPECT_NE(managed, nullptr); - EXPECT_NE(nullptr, managed); -} - struct NotTriviallyDestuctible final { ~NotTriviallyDestuctible() { Delete(); } @@ -37,28 +28,17 @@ struct NotTriviallyDestuctible final { }; TEST(GlobalMemoryManager, NotTriviallyDestuctible) { - EXPECT_FALSE(std::is_trivially_destructible_v); - auto managed = MemoryManager::Global().New(); - EXPECT_NE(managed, nullptr); - EXPECT_NE(nullptr, managed); + auto managed = MakeUnique(MemoryManager::Global()); EXPECT_CALL(*managed, Delete()); } -TEST(ManagedMemory, Null) { - EXPECT_EQ(ManagedMemory(), nullptr); - EXPECT_EQ(nullptr, ManagedMemory()); -} - -struct LargeStruct { - char padding[4096 - alignof(char)]; -}; - -TEST(DefaultArenaMemoryManager, OddSizes) { +TEST(ArenaMemoryManager, NotTriviallyDestuctible) { auto memory_manager = ArenaMemoryManager::Default(); - size_t page_size = base_internal::GetPageSize(); - for (size_t allocated = 0; allocated <= page_size; - allocated += sizeof(LargeStruct)) { - static_cast(memory_manager->New()); + { + // Destructor is called when UniqueRef is destructed, not on MemoryManager + // destruction. + auto managed = MakeUnique(*memory_manager); + EXPECT_CALL(*managed, Delete()); } } diff --git a/eval/eval/BUILD b/eval/eval/BUILD index 4880a208a..33a3a8b19 100644 --- a/eval/eval/BUILD +++ b/eval/eval/BUILD @@ -202,6 +202,7 @@ cc_library( "//eval/public:cel_function_registry", "//eval/public:cel_value", "//eval/public:unknown_set", + "//extensions/protobuf:memory_manager", "//internal:status_macros", "//runtime:activation_interface", "//runtime:function_overload_reference", @@ -260,6 +261,7 @@ cc_library( "//base:handle", "//eval/internal:interop", "//eval/public/containers:container_backed_list_impl", + "//extensions/protobuf:memory_manager", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", @@ -280,6 +282,7 @@ cc_library( "//eval/internal:interop", "//eval/public:cel_value", "//eval/public/containers:container_backed_map_impl", + "//extensions/protobuf:memory_manager", "//internal:status_macros", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -709,6 +712,7 @@ cc_library( "//base:memory_manager", "//base:value", "//eval/public:unknown_set", + "//extensions/protobuf:memory_manager", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@com_google_protobuf//:protobuf", diff --git a/eval/eval/attribute_utility.cc b/eval/eval/attribute_utility.cc index 40bf628c4..27c1afea4 100644 --- a/eval/eval/attribute_utility.cc +++ b/eval/eval/attribute_utility.cc @@ -4,6 +4,7 @@ #include "base/attribute_set.h" #include "base/values/unknown_value.h" +#include "extensions/protobuf/memory_manager.h" namespace google::api::expr::runtime { @@ -71,8 +72,9 @@ const UnknownSet* AttributeUtility::MergeUnknowns( return initial_set; } - return memory_manager_.New(std::move(result_set).value()) - .release(); + return google::protobuf::Arena::Create( + cel::extensions::ProtoMemoryManager::CastToProtoArena(memory_manager_), + std::move(result_set).value()); } // Creates merged UnknownAttributeSet. @@ -118,9 +120,26 @@ const UnknownSet* AttributeUtility::MergeUnknowns( result_set, UnknownSet(unknown_value->attribute_set(), unknown_value->function_result_set())); } - return memory_manager_.New(std::move(result_set)).release(); + return google::protobuf::Arena::Create( + cel::extensions::ProtoMemoryManager::CastToProtoArena(memory_manager_), + std::move(result_set)); } return MergeUnknowns(args, initial_set); } +const UnknownSet* AttributeUtility::CreateUnknownSet( + cel::Attribute attr) const { + return google::protobuf::Arena::Create( + cel::extensions::ProtoMemoryManager::CastToProtoArena(memory_manager_), + UnknownAttributeSet({std::move(attr)})); +} + +const UnknownSet* AttributeUtility::CreateUnknownSet( + const cel::FunctionDescriptor& fn_descriptor, int64_t expr_id, + absl::Span> args) const { + return google::protobuf::Arena::Create( + cel::extensions::ProtoMemoryManager::CastToProtoArena(memory_manager_), + cel::FunctionResultSet(cel::FunctionResult(fn_descriptor, expr_id))); +} + } // namespace google::api::expr::runtime diff --git a/eval/eval/attribute_utility.h b/eval/eval/attribute_utility.h index a7e003ac6..2dc8260fa 100644 --- a/eval/eval/attribute_utility.h +++ b/eval/eval/attribute_utility.h @@ -71,21 +71,12 @@ class AttributeUtility { bool use_partial) const; // Create an initial UnknownSet from a single attribute. - const UnknownSet* CreateUnknownSet(cel::Attribute attr) const { - return memory_manager_ - .New(UnknownAttributeSet({std::move(attr)})) - .release(); - } + const UnknownSet* CreateUnknownSet(cel::Attribute attr) const; // Create an initial UnknownSet from a single missing function call. const UnknownSet* CreateUnknownSet( const cel::FunctionDescriptor& fn_descriptor, int64_t expr_id, - absl::Span> args) const { - return memory_manager_ - .New( - cel::FunctionResultSet(cel::FunctionResult(fn_descriptor, expr_id))) - .release(); - } + absl::Span> args) const; private: absl::Span unknown_patterns_; diff --git a/eval/eval/create_list_step.cc b/eval/eval/create_list_step.cc index be6ce063e..e0124d3dd 100644 --- a/eval/eval/create_list_step.cc +++ b/eval/eval/create_list_step.cc @@ -11,6 +11,7 @@ #include "eval/eval/mutable_list_impl.h" #include "eval/internal/interop.h" #include "eval/public/containers/container_backed_list_impl.h" +#include "extensions/protobuf/memory_manager.h" namespace google::api::expr::runtime { @@ -71,20 +72,19 @@ absl::Status CreateListStep::Evaluate(ExecutionFrame* frame) const { } } + auto* arena = cel::extensions::ProtoMemoryManager::CastToProtoArena( + frame->memory_manager()); + if (immutable_) { // TODO(issues/5): switch to new cel::ListValue in phase 2 - result = CreateLegacyListValue( - frame->memory_manager() - .New( - ModernValueToLegacyValueOrDie(frame->memory_manager(), args)) - .release()); + result = + CreateLegacyListValue(google::protobuf::Arena::Create( + arena, + ModernValueToLegacyValueOrDie(frame->memory_manager(), args))); } else { // TODO(issues/5): switch to new cel::ListValue in phase 2 - result = CreateLegacyListValue( - frame->memory_manager() - .New( - ModernValueToLegacyValueOrDie(frame->memory_manager(), args)) - .release()); + result = CreateLegacyListValue(google::protobuf::Arena::Create( + arena, ModernValueToLegacyValueOrDie(frame->memory_manager(), args))); } frame->value_stack().Pop(list_size_); frame->value_stack().Push(std::move(result)); diff --git a/eval/eval/create_struct_step.cc b/eval/eval/create_struct_step.cc index ce101de00..70d3203cd 100644 --- a/eval/eval/create_struct_step.cc +++ b/eval/eval/create_struct_step.cc @@ -14,6 +14,7 @@ #include "eval/internal/interop.h" #include "eval/public/cel_value.h" #include "eval/public/containers/container_backed_map_impl.h" +#include "extensions/protobuf/memory_manager.h" #include "internal/status_macros.h" namespace google::api::expr::runtime { @@ -106,10 +107,10 @@ absl::Status CreateStructStepForMessage::Evaluate(ExecutionFrame* frame) const { if (status_or_result.ok()) { result = std::move(status_or_result).value(); } else { - result = CreateErrorValueFromView( - frame->memory_manager() - .New(status_or_result.status()) - .release()); + result = CreateErrorValueFromView(google::protobuf::Arena::Create( + cel::extensions::ProtoMemoryManager::CastToProtoArena( + frame->memory_manager()), + status_or_result.status())); } frame->value_stack().Pop(entries_.size()); frame->value_stack().Push(std::move(result)); @@ -131,7 +132,9 @@ absl::StatusOr> CreateStructStepForMap::DoEvaluate( } // TODO(issues/5): switch to new cel::MapValue in phase 2 - auto map_builder = frame->memory_manager().New(); + auto* map_builder = google::protobuf::Arena::Create( + cel::extensions::ProtoMemoryManager::CastToProtoArena( + frame->memory_manager())); for (size_t i = 0; i < entry_count_; i += 1) { int map_key_index = 2 * i; @@ -144,12 +147,14 @@ absl::StatusOr> CreateStructStepForMap::DoEvaluate( map_key, cel::interop_internal::ModernValueToLegacyValueOrDie( frame->memory_manager(), args[map_value_index])); if (!key_status.ok()) { - return CreateErrorValueFromView( - frame->memory_manager().New(key_status).release()); + return CreateErrorValueFromView(google::protobuf::Arena::Create( + cel::extensions::ProtoMemoryManager::CastToProtoArena( + frame->memory_manager()), + key_status)); } } - return CreateLegacyMapValue(map_builder.release()); + return CreateLegacyMapValue(map_builder); } absl::Status CreateStructStepForMap::Evaluate(ExecutionFrame* frame) const { diff --git a/eval/eval/function_step.cc b/eval/eval/function_step.cc index 230bddf2a..9af0d05b1 100644 --- a/eval/eval/function_step.cc +++ b/eval/eval/function_step.cc @@ -30,6 +30,7 @@ #include "eval/public/cel_function_registry.h" #include "eval/public/cel_value.h" #include "eval/public/unknown_set.h" +#include "extensions/protobuf/memory_manager.h" #include "internal/status_macros.h" #include "runtime/activation_interface.h" #include "runtime/function_overload_reference.h" @@ -89,9 +90,10 @@ std::vector> CheckForPartialUnknowns( auto attr_set = frame->attribute_utility().CheckForUnknowns( attrs.subspan(i, 1), /*use_partial=*/true); if (!attr_set.empty()) { - auto unknown_set = frame->memory_manager() - .New(std::move(attr_set)) - .release(); + auto unknown_set = google::protobuf::Arena::Create( + cel::extensions::ProtoMemoryManager::CastToProtoArena( + frame->memory_manager()), + std::move(attr_set)); result.push_back( cel::interop_internal::CreateUnknownValueFromView(unknown_set)); } else { diff --git a/eval/internal/interop_test.cc b/eval/internal/interop_test.cc index a02e05cc3..25399c5b4 100644 --- a/eval/internal/interop_test.cc +++ b/eval/internal/interop_test.cc @@ -368,11 +368,10 @@ TEST(ValueInterop, ListFromLegacy) { TypeFactory type_factory(memory_manager); TypeManager type_manager(type_factory, TypeProvider::Builtin()); ValueFactory value_factory(type_manager); - auto legacy_value = CelValue::CreateList( - memory_manager - .New( - std::vector{CelValue::CreateInt64(0)}) - .release()); + auto legacy_value = + CelValue::CreateList(google::protobuf::Arena::Create< + google::api::expr::runtime::ContainerBackedListImpl>( + &arena, std::vector{CelValue::CreateInt64(0)})); ASSERT_OK_AND_ASSIGN(auto value, FromLegacyValue(&arena, legacy_value)); EXPECT_TRUE(value->Is()); EXPECT_EQ(value.As()->size(), 1); @@ -457,11 +456,10 @@ TEST(ValueInterop, LegacyListRoundtrip) { TypeFactory type_factory(memory_manager); TypeManager type_manager(type_factory, TypeProvider::Builtin()); ValueFactory value_factory(type_manager); - auto value = CelValue::CreateList( - memory_manager - .New( - std::vector{CelValue::CreateInt64(0)}) - .release()); + auto value = + CelValue::CreateList(google::protobuf::Arena::Create< + google::api::expr::runtime::ContainerBackedListImpl>( + &arena, std::vector{CelValue::CreateInt64(0)})); ASSERT_OK_AND_ASSIGN(auto modern_value, FromLegacyValue(&arena, value)); ASSERT_OK_AND_ASSIGN(auto legacy_value, ToLegacyValue(&arena, modern_value)); EXPECT_EQ(value.ListOrDie(), legacy_value.ListOrDie()); @@ -474,7 +472,7 @@ TEST(ValueInterop, MapFromLegacy) { TypeManager type_manager(type_factory, TypeProvider::Builtin()); ValueFactory value_factory(type_manager); auto* legacy_map = - memory_manager.New().release(); + google::protobuf::Arena::Create(&arena); ASSERT_OK(legacy_map->Add(CelValue::CreateInt64(1), CelValue::CreateStringView("foo"))); auto legacy_value = CelValue::CreateMap(legacy_map); @@ -605,8 +603,7 @@ TEST(ValueInterop, LegacyMapRoundtrip) { TypeManager type_manager(type_factory, TypeProvider::Builtin()); ValueFactory value_factory(type_manager); auto value = CelValue::CreateMap( - memory_manager.New() - .release()); + google::protobuf::Arena::Create(&arena)); ASSERT_OK_AND_ASSIGN(auto modern_value, FromLegacyValue(&arena, value)); ASSERT_OK_AND_ASSIGN(auto legacy_value, ToLegacyValue(&arena, modern_value)); EXPECT_EQ(value.MapOrDie(), legacy_value.MapOrDie()); diff --git a/eval/public/BUILD b/eval/public/BUILD index cd82c52be..c761311aa 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -1246,14 +1246,13 @@ cc_test( "//eval/public/structs:legacy_type_info_apis", "//eval/public/structs:legacy_type_provider", "//eval/testutil:test_message_cc_proto", + "//extensions/protobuf:memory_manager", "//internal:casts", "//internal:proto_time_encoding", "//internal:testing", "//parser", - "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/types:optional", "@com_google_protobuf//:protobuf", ], ) diff --git a/eval/public/portable_cel_expr_builder_factory_test.cc b/eval/public/portable_cel_expr_builder_factory_test.cc index 5dbfdeb77..9d1857043 100644 --- a/eval/public/portable_cel_expr_builder_factory_test.cc +++ b/eval/public/portable_cel_expr_builder_factory_test.cc @@ -32,6 +32,7 @@ #include "eval/public/structs/legacy_type_info_apis.h" #include "eval/public/structs/legacy_type_provider.h" #include "eval/testutil/test_message.pb.h" +#include "extensions/protobuf/memory_manager.h" #include "internal/casts.h" #include "internal/proto_time_encoding.h" #include "internal/testing.h" @@ -370,8 +371,9 @@ const LegacyTypeAccessApis* DemoTypeInfo::GetAccessApis( absl::StatusOr DemoTimestamp::NewInstance( cel::MemoryManager& memory_manager) const { - auto ts = memory_manager.New(); - return CelValue::MessageWrapper::Builder(ts.release()); + auto* ts = google::protobuf::Arena::CreateMessage( + cel::extensions::ProtoMemoryManager::CastToProtoArena(memory_manager)); + return CelValue::MessageWrapper::Builder(ts); } absl::StatusOr DemoTimestamp::AdaptFromWellKnownType( cel::MemoryManager& memory_manager, @@ -421,8 +423,9 @@ DemoTestMessage::DemoTestMessage(const DemoTypeProvider* owning_provider) absl::StatusOr DemoTestMessage::NewInstance( cel::MemoryManager& memory_manager) const { - auto ts = memory_manager.New(); - return CelValue::MessageWrapper::Builder(ts.release()); + auto* ts = google::protobuf::Arena::CreateMessage( + cel::extensions::ProtoMemoryManager::CastToProtoArena(memory_manager)); + return CelValue::MessageWrapper::Builder(ts); } absl::Status DemoTestMessage::SetField( diff --git a/eval/public/structs/proto_message_type_adapter.cc b/eval/public/structs/proto_message_type_adapter.cc index a2fce9802..556108179 100644 --- a/eval/public/structs/proto_message_type_adapter.cc +++ b/eval/public/structs/proto_message_type_adapter.cc @@ -134,15 +134,15 @@ absl::StatusOr GetFieldImpl(const google::protobuf::Message* message, google::protobuf::Arena* arena = ProtoMemoryManager::CastToProtoArena(memory_manager); if (field_desc->is_map()) { - auto map = memory_manager.New( - message, field_desc, &MessageCelValueFactory, arena); + auto* map = google::protobuf::Arena::Create( + arena, message, field_desc, &MessageCelValueFactory, arena); - return CelValue::CreateMap(map.release()); + return CelValue::CreateMap(map); } if (field_desc->is_repeated()) { - auto list = memory_manager.New( - message, field_desc, &MessageCelValueFactory, arena); - return CelValue::CreateList(list.release()); + auto* list = google::protobuf::Arena::Create( + arena, message, field_desc, &MessageCelValueFactory, arena); + return CelValue::CreateList(list); } CEL_ASSIGN_OR_RETURN( diff --git a/extensions/protobuf/memory_manager_test.cc b/extensions/protobuf/memory_manager_test.cc index 1290d8b7b..62574cf5d 100644 --- a/extensions/protobuf/memory_manager_test.cc +++ b/extensions/protobuf/memory_manager_test.cc @@ -14,91 +14,27 @@ #include "extensions/protobuf/memory_manager.h" -#include "google/protobuf/struct.pb.h" #include "google/protobuf/arena.h" #include "internal/testing.h" namespace cel::extensions { namespace { -struct NotArenaCompatible final { - ~NotArenaCompatible() { Delete(); } - - MOCK_METHOD(void, Delete, (), ()); -}; - -TEST(ProtoMemoryManager, ArenaConstructable) { - google::protobuf::Arena arena; - ProtoMemoryManager memory_manager(&arena); - EXPECT_TRUE( - google::protobuf::Arena::is_arena_constructable::value); - auto* object = NewInProtoArena(memory_manager); - EXPECT_NE(object, nullptr); -} - -TEST(ProtoMemoryManager, NotArenaConstructable) { - google::protobuf::Arena arena; - ProtoMemoryManager memory_manager(&arena); - EXPECT_FALSE( - google::protobuf::Arena::is_arena_constructable::value); - auto* object = NewInProtoArena(memory_manager); - EXPECT_NE(object, nullptr); - EXPECT_CALL(*object, Delete()); -} - -TEST(ProtoMemoryManagerNoArena, ArenaConstructable) { - ProtoMemoryManager memory_manager(nullptr); - EXPECT_TRUE( - google::protobuf::Arena::is_arena_constructable::value); - auto* object = NewInProtoArena(memory_manager); - EXPECT_NE(object, nullptr); - delete object; -} - -TEST(ProtoMemoryManagerNoArena, NotArenaConstructable) { - ProtoMemoryManager memory_manager(nullptr); - EXPECT_FALSE( - google::protobuf::Arena::is_arena_constructable::value); - auto* object = NewInProtoArena(memory_manager); - EXPECT_NE(object, nullptr); - EXPECT_CALL(*object, Delete()); - delete object; -} - -struct TriviallyDestructible final {}; - struct NotTriviallyDestuctible final { ~NotTriviallyDestuctible() { Delete(); } MOCK_METHOD(void, Delete, (), ()); }; -TEST(ProtoMemoryManager, TriviallyDestructible) { - google::protobuf::Arena arena; - ProtoMemoryManager memory_manager(&arena); - EXPECT_TRUE(std::is_trivially_destructible_v); - auto managed = memory_manager.New(); -} - TEST(ProtoMemoryManager, NotTriviallyDestuctible) { google::protobuf::Arena arena; ProtoMemoryManager memory_manager(&arena); - EXPECT_FALSE(std::is_trivially_destructible_v); - auto managed = memory_manager.New(); - EXPECT_CALL(*managed, Delete()); -} - -TEST(ProtoMemoryManagerNoArena, TriviallyDestructible) { - ProtoMemoryManager memory_manager(nullptr); - EXPECT_TRUE(std::is_trivially_destructible_v); - auto managed = memory_manager.New(); -} - -TEST(ProtoMemoryManagerNoArena, NotTriviallyDestuctible) { - ProtoMemoryManager memory_manager(nullptr); - EXPECT_FALSE(std::is_trivially_destructible_v); - auto managed = memory_manager.New(); - EXPECT_CALL(*managed, Delete()); + { + // Destructor is called when UniqueRef is destructed, not on MemoryManager + // destruction. + auto managed = MakeUnique(memory_manager); + EXPECT_CALL(*managed, Delete()); + } } } // namespace From 058603e3c2edda4e26314fe262fe56a6936ff162 Mon Sep 17 00:00:00 2001 From: jcking Date: Thu, 4 May 2023 19:36:55 +0000 Subject: [PATCH 242/303] Improve safety and efficiency of handle construction PiperOrigin-RevId: 529484601 --- base/BUILD | 2 - base/handle.h | 33 +++++++------ base/internal/BUILD | 1 + base/internal/data.h | 102 +++++++++++++++++++++++++++++++++++------ base/internal/handle.h | 32 ++++++++++++- base/memory.h | 22 ++++----- base/type.h | 13 ++++-- base/value.h | 14 ++++-- 8 files changed, 171 insertions(+), 48 deletions(-) diff --git a/base/BUILD b/base/BUILD index 98612af41..e493b09a6 100644 --- a/base/BUILD +++ b/base/BUILD @@ -63,7 +63,6 @@ cc_library( "//base/internal:handle", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log:absl_check", - "@com_google_absl//absl/utility", ], ) @@ -196,7 +195,6 @@ cc_library( "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:variant", - "@com_google_absl//absl/utility", ], ) diff --git a/base/handle.h b/base/handle.h index 90b860917..24c8c8b52 100644 --- a/base/handle.h +++ b/base/handle.h @@ -20,7 +20,6 @@ #include "absl/base/attributes.h" #include "absl/log/absl_check.h" -#include "absl/utility/utility.h" #include "base/internal/data.h" #include "base/internal/handle.h" // IWYU pragma: export @@ -252,9 +251,17 @@ class Handle final : private base_internal::HandlePolicy { friend struct base_internal::HandleFactory; friend class MemoryManager; - template - explicit Handle(absl::in_place_t, Args&&... args) - : impl_(std::forward(args)...) {} + template + explicit Handle(base_internal::InPlaceStoredInline tag, Args&&... args) + : impl_(tag, std::forward(args)...) {} + + Handle(base_internal::InPlaceArenaAllocated tag, + typename Impl::base_type& arg) + : impl_(tag, arg) {} + + Handle(base_internal::InPlaceReferenceCounted tag, + typename Impl::base_type& arg) + : impl_(tag, arg) {} Impl impl_; }; @@ -268,32 +275,32 @@ namespace cel::base_internal { template struct HandleFactory { + static_assert(IsDerivedDataV); + // Constructs a handle whose underlying object is stored in the // handle itself. template - static std::enable_if_t, Handle> Make( + static std::enable_if_t, Handle> Make( Args&&... args) { - static_assert(std::is_base_of_v, "T is not derived from Data"); static_assert(std::is_base_of_v, "F is not derived from T"); - return Handle(absl::in_place, absl::in_place_type, - std::forward(args)...); + return Handle(kInPlaceStoredInline, std::forward(args)...); } + // Constructs a handle whose underlying object is stored in the // handle itself. template - static std::enable_if_t, void> MakeAt( + static std::enable_if_t, void> MakeAt( void* address, Args&&... args) { - static_assert(std::is_base_of_v, "T is not derived from Data"); static_assert(std::is_base_of_v, "F is not derived from T"); - ::new (address) Handle(absl::in_place, absl::in_place_type, - std::forward(args)...); + ::new (address) + Handle(kInPlaceStoredInline, std::forward(args)...); } // Constructs a handle whose underlying object is heap allocated // and potentially reference counted, depending on the memory manager // implementation. template - static std::enable_if_t, Handle> Make( + static std::enable_if_t, Handle> Make( MemoryManager& memory_manager, Args&&... args); }; diff --git a/base/internal/BUILD b/base/internal/BUILD index 00e01b1dd..6c364a30b 100644 --- a/base/internal/BUILD +++ b/base/internal/BUILD @@ -30,6 +30,7 @@ cc_library( hdrs = ["data.h"], deps = [ "//base:kind", + "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/numeric:bits", ], diff --git a/base/internal/data.h b/base/internal/data.h index 5dba01e49..eb1c9121c 100644 --- a/base/internal/data.h +++ b/base/internal/data.h @@ -20,10 +20,10 @@ #include #include #include -#include #include #include "absl/base/attributes.h" +#include "absl/base/casts.h" #include "absl/base/macros.h" #include "absl/base/optimization.h" #include "absl/numeric/bits.h" @@ -74,7 +74,7 @@ inline constexpr uintptr_t kInlineVariantBits = uintptr_t{0xf} << kInlineVariantShift; // We assert some expectations we have around alignment, size, and trivial -// destructability. +// destructibility. static_assert(sizeof(uintptr_t) == sizeof(std::atomic), "uintptr_t and std::atomic must have the same size"); static_assert(sizeof(void*) == sizeof(uintptr_t), @@ -368,7 +368,13 @@ using SelectMetadata = typename SelectMetadataImpl::type; template union alignas(Align) AnyDataStorage final { +#ifdef NDEBUG + // Only need to clear the pointer for this to appear as empty. AnyDataStorage() : pointer(0) {} +#else + // In debug builds we clear the entire storage to help identify misuse. + AnyDataStorage() { std::memset(buffer, '\0', sizeof(buffer)); } +#endif uintptr_t pointer; uint8_t buffer[Size]; @@ -464,11 +470,11 @@ struct AnyData final { } Data* get_inline() const { - return static_cast(const_cast(buffer())); + return absl::bit_cast(const_cast(buffer())); } Data* get_heap() const { - return reinterpret_cast(pointer() & kPointerMask); + return absl::bit_cast(pointer() & kPointerMask); } // Copy the bytes from other, similar to `std::memcpy`. @@ -484,33 +490,47 @@ struct AnyData final { template void Destruct() { + static_assert(sizeof(T) <= kSize); + static_assert(alignof(T) <= kAlign); ABSL_ASSERT(IsStoredInline()); static_cast(get_inline())->~T(); } void Clear() { +#ifdef NDEBUG // We only need to clear the first `sizeof(uintptr_t)` bytes as that is // consulted to determine locality. set_pointer(0); +#else + // In debug builds, we clear all the storage to help identify misuse. + std::memset(buffer(), '\0', kSize); +#endif } // Counterpart to `Metadata::SetArenaAllocated()` and // `Metadata::SetReferenceCounted()`, also used by `MemoryManager`. - void ConstructHeap(const Data& data) { - ABSL_ASSERT(absl::countr_zero(reinterpret_cast(&data)) >= + void ConstructReferenceCounted(const Data& data) { + uintptr_t pointer = absl::bit_cast(std::addressof(data)); + ABSL_ASSERT(absl::countr_zero(pointer) >= 2); // Assert pointer alignment results in at least the 2 least // significant bits being unset. - set_pointer(reinterpret_cast(&data) | - ((reinterpret_cast*>( - reinterpret_cast(&data) + sizeof(uintptr_t)) - ->load(std::memory_order_relaxed) & - kArenaAllocated) == kArenaAllocated - ? kPointerArenaAllocated - : kPointerReferenceCounted)); + set_pointer(pointer | kPointerReferenceCounted); + } + + // Counterpart to `Metadata::SetArenaAllocated()` and + // `Metadata::SetReferenceCounted()`, also used by `MemoryManager`. + void ConstructArenaAllocated(const Data& data) { + uintptr_t pointer = absl::bit_cast(std::addressof(data)); + ABSL_ASSERT(absl::countr_zero(pointer) >= + 2); // Assert pointer alignment results in at least the 2 least + // significant bits being unset. + set_pointer(pointer | kPointerArenaAllocated); } template void ConstructInline(Args&&... args) { + static_assert(sizeof(T) <= kSize); + static_assert(alignof(T) <= kAlign); ::new (buffer()) T(std::forward(args)...); ABSL_ASSERT(IsStoredInline()); } @@ -526,6 +546,62 @@ struct AnyData final { Storage storage; }; +template +struct IsData + : public std::integral_constant> {}; + +template +inline constexpr bool IsDataV = IsData::value; + +template +struct IsDerivedData + : public std::integral_constant< + bool, std::conjunction_v< + std::is_base_of, + std::negation>>>> {}; + +template +inline constexpr bool IsDerivedDataV = IsDerivedData::value; + +template +struct IsInlineData + : public std::integral_constant< + bool, std::conjunction_v, std::is_base_of>> { +}; + +template +inline constexpr bool IsInlineDataV = IsInlineData::value; + +template +struct IsDerivedInlineData + : public std::integral_constant< + bool, + std::conjunction_v< + IsInlineData, IsDerivedData, + std::negation>>>> {}; + +template +inline constexpr bool IsDerivedInlineDataV = IsDerivedInlineData::value; + +template +struct IsHeapData + : public std::integral_constant< + bool, std::conjunction_v, std::is_base_of>> {}; + +template +inline constexpr bool IsHeapDataV = IsHeapData::value; + +template +struct IsDerivedHeapData + : public std::integral_constant< + bool, + std::conjunction_v< + IsHeapData, IsDerivedData, + std::negation>>>> {}; + +template +inline constexpr bool IsDerivedHeapDataV = IsDerivedHeapData::value; + } // namespace base_internal } // namespace cel diff --git a/base/internal/handle.h b/base/internal/handle.h index 8e5a17b09..ba8219a96 100644 --- a/base/internal/handle.h +++ b/base/internal/handle.h @@ -38,10 +38,38 @@ struct HandlePolicy { static_assert(std::is_class_v, "Handles only support classes"); static_assert(!std::is_volatile_v, "Handles do not support volatile"); static_assert(!std::is_const_v, "Handles do not support const"); - static_assert((std::is_base_of_v && !std::is_same_v), - "Handles do not support this type"); + static_assert(IsDerivedDataV, "Handles do not support this type"); }; +// Tag type used to select the correct Handle constructor for constructing +// inline data. +template +struct InPlaceStoredInline { + explicit InPlaceStoredInline() = default; +}; + +template +inline constexpr InPlaceStoredInline kInPlaceStoredInline = + InPlaceStoredInline{}; + +// Tag type used to select the correct Handle constructor for constructing +// from reference counted data. +struct InPlaceReferenceCounted { + explicit InPlaceReferenceCounted() = default; +}; + +inline constexpr InPlaceReferenceCounted kInPlaceReferenceCounted = + InPlaceReferenceCounted{}; + +// Tag type used to select the correct Handle constructor for constructing +// from arena allocated data. +struct InPlaceArenaAllocated { + explicit InPlaceArenaAllocated() = default; +}; + +inline constexpr InPlaceArenaAllocated kInPlaceArenaAllocated = + InPlaceArenaAllocated{}; + } // namespace cel::base_internal #endif // THIRD_PARTY_CEL_CPP_BASE_INTERNAL_HANDLE_PRE_H_ diff --git a/base/memory.h b/base/memory.h index b564429db..4e49832b4 100644 --- a/base/memory.h +++ b/base/memory.h @@ -174,13 +174,9 @@ class MemoryManager { template Handle AllocateHandle(Args&&... args) ABSL_ATTRIBUTE_LIFETIME_BOUND ABSL_MUST_USE_RESULT { - static_assert(std::is_base_of_v); - T* pointer; - if (!allocation_only_) { - pointer = new T(std::forward(args)...); - base_internal::Metadata::SetReferenceCounted(*pointer); - } else { - pointer = ::new (Allocate(sizeof(T), alignof(T))) + static_assert(base_internal::IsDerivedHeapDataV); + if (allocation_only_) { + T* pointer = ::new (Allocate(sizeof(T), alignof(T))) T(std::forward(args)...); if constexpr (!std::is_trivially_destructible_v) { if constexpr (base_internal::HasIsDestructorSkippable::value) { @@ -194,14 +190,17 @@ class MemoryManager { } } base_internal::Metadata::SetArenaAllocated(*pointer); + return Handle(base_internal::kInPlaceArenaAllocated, *pointer); } - return Handle(absl::in_place, *pointer); + T* pointer = new T(std::forward(args)...); + base_internal::Metadata::SetReferenceCounted(*pointer); + return Handle(base_internal::kInPlaceReferenceCounted, *pointer); } template UniqueRef AllocateUnique(Args&&... args) ABSL_ATTRIBUTE_LIFETIME_BOUND ABSL_MUST_USE_RESULT { - static_assert(!std::is_base_of_v); + static_assert(!base_internal::IsDataV); T* ptr; if (allocation_only_) { ptr = ::new (Allocate(sizeof(T), alignof(T))) @@ -357,9 +356,8 @@ namespace base_internal { template template -std::enable_if_t, Handle> -HandleFactory::Make(MemoryManager& memory_manager, Args&&... args) { - static_assert(std::is_base_of_v, "T is not derived from Data"); +std::enable_if_t, Handle> HandleFactory::Make( + MemoryManager& memory_manager, Args&&... args) { static_assert(std::is_base_of_v, "F is not derived from T"); #if defined(__cpp_lib_is_pointer_interconvertible) && \ __cpp_lib_is_pointer_interconvertible >= 201907L diff --git a/base/type.h b/base/type.h index 123a4c62d..764e9cd9a 100644 --- a/base/type.h +++ b/base/type.h @@ -24,7 +24,6 @@ #include "absl/base/optimization.h" #include "absl/hash/hash.h" #include "absl/strings/string_view.h" -#include "absl/utility/utility.h" #include "base/handle.h" #include "base/internal/data.h" #include "base/internal/type.h" // IWYU pragma: export @@ -146,14 +145,22 @@ class TypeMetadata final { class TypeHandle final { public: + using base_type = Type; + TypeHandle() = default; template - explicit TypeHandle(absl::in_place_type_t, Args&&... args) { + explicit TypeHandle(InPlaceStoredInline, Args&&... args) { data_.ConstructInline(std::forward(args)...); } - explicit TypeHandle(const Type& type) { data_.ConstructHeap(type); } + explicit TypeHandle(InPlaceArenaAllocated, Type& arg) { + data_.ConstructArenaAllocated(arg); + } + + explicit TypeHandle(InPlaceReferenceCounted, Type& arg) { + data_.ConstructReferenceCounted(arg); + } TypeHandle(const TypeHandle& other) { CopyFrom(other); } diff --git a/base/value.h b/base/value.h index 9155e821f..4af859689 100644 --- a/base/value.h +++ b/base/value.h @@ -53,7 +53,7 @@ class Value : public base_internal::Data { static const Value& Cast(const Value& value) { return value; } // Returns the kind of the value. This is equivalent to `type().kind()` but - // faster in many scenarios. As such it should be preffered when only the kind + // faster in many scenarios. As such it should be preferred when only the kind // is required. Kind kind() const { return base_internal::Metadata::Kind(*this); } @@ -135,14 +135,22 @@ class ValueMetadata final { class ValueHandle final { public: + using base_type = Value; + ValueHandle() = default; template - explicit ValueHandle(absl::in_place_type_t in_place_type, Args&&... args) { + explicit ValueHandle(InPlaceStoredInline, Args&&... args) { data_.ConstructInline(std::forward(args)...); } - explicit ValueHandle(const Value& value) { data_.ConstructHeap(value); } + explicit ValueHandle(InPlaceArenaAllocated, Value& arg) { + data_.ConstructArenaAllocated(arg); + } + + explicit ValueHandle(InPlaceReferenceCounted, Value& arg) { + data_.ConstructReferenceCounted(arg); + } ValueHandle(const ValueHandle& other) { CopyFrom(other); } From 9ba7ad80a043b5700cf8303052a8a01485c20544 Mon Sep 17 00:00:00 2001 From: jcking Date: Fri, 5 May 2023 18:31:34 +0000 Subject: [PATCH 243/303] More cleanup of handle infrastructure PiperOrigin-RevId: 529770702 --- base/internal/data.h | 43 ++++++++++++++++++++++++++----------------- 1 file changed, 26 insertions(+), 17 deletions(-) diff --git a/base/internal/data.h b/base/internal/data.h index eb1c9121c..b18a5e434 100644 --- a/base/internal/data.h +++ b/base/internal/data.h @@ -260,17 +260,17 @@ class Metadata final { static void Ref(const Data& data) { ABSL_ASSERT(IsReferenceCounted(data)); - const auto count = - (ReferenceCount(data).fetch_add(1, std::memory_order_relaxed)) & - kReferenceCountMask; + const auto count = (ReferenceCount(const_cast(data)) + .fetch_add(1, std::memory_order_relaxed)) & + kReferenceCountMask; ABSL_ASSERT(count > 0 && count < kReferenceCountMax); } ABSL_MUST_USE_RESULT static bool Unref(const Data& data) { ABSL_ASSERT(IsReferenceCounted(data)); - const auto count = - (ReferenceCount(data).fetch_sub(1, std::memory_order_seq_cst)) & - kReferenceCountMask; + const auto count = (ReferenceCount(const_cast(data)) + .fetch_sub(1, std::memory_order_seq_cst)) & + kReferenceCountMask; ABSL_ASSERT(count > 0 && count < kReferenceCountMax); return count == 1; } @@ -284,7 +284,7 @@ class Metadata final { static bool IsUnique(const Data& data) { ABSL_ASSERT(IsReferenceCounted(data)); - return ((ReferenceCount(data).fetch_add(1, std::memory_order_acquire)) & + return (ReferenceCount(data).load(std::memory_order_acquire) & kReferenceCountMask) == 1; } @@ -294,12 +294,12 @@ class Metadata final { } // Used by `MemoryManager::New()`. - static void SetArenaAllocated(const Data& data) { + static void SetArenaAllocated(Data& data) { ReferenceCount(data).fetch_or(kArenaAllocated, std::memory_order_relaxed); } // Used by `MemoryManager::New()`. - static void SetReferenceCounted(const Data& data) { + static void SetReferenceCounted(Data& data) { ReferenceCount(data).fetch_or(kReferenceCounted, std::memory_order_relaxed); } @@ -326,16 +326,23 @@ class Metadata final { static uintptr_t VirtualPointer(const Data& data) { // The vptr, or equivalent, is stored at offset 0. Inform the compiler that // `data` is aligned to at least `uintptr_t`. - return *reinterpret_cast(&data); + return *absl::bit_cast(std::addressof(data)); } - static std::atomic& ReferenceCount(const Data& data) { + static const std::atomic& ReferenceCount(const Data& data) { // For arena allocated and reference counted, the reference count // immediately follows the vptr, or equivalent, at offset 0. So its offset - // is `sizeof(uintptr_t)`. Inform the compiler that `data` is aligned to at - // least `uintptr_t` and `std::atomic`. - return *reinterpret_cast*>(const_cast( - reinterpret_cast(&data) + sizeof(uintptr_t))); + // is `sizeof(uintptr_t)`. + return *absl::bit_cast*>( + absl::bit_cast(std::addressof(data)) + sizeof(uintptr_t)); + } + + static std::atomic& ReferenceCount(Data& data) { + // For arena allocated and reference counted, the reference count + // immediately follows the vptr, or equivalent, at offset 0. So its offset + // is `sizeof(uintptr_t)`. + return const_cast&>( + ReferenceCount(static_cast(data))); } Metadata() = delete; @@ -405,8 +412,8 @@ struct AnyData final { Kind kind_heap() const { return static_cast( - ((reinterpret_cast*>((pointer() & kPointerMask) + - sizeof(uintptr_t)) + ((absl::bit_cast*>((pointer() & kPointerMask) + + sizeof(uintptr_t)) ->load(std::memory_order_relaxed)) >> kKindShift) & kKindMask); @@ -515,6 +522,7 @@ struct AnyData final { 2); // Assert pointer alignment results in at least the 2 least // significant bits being unset. set_pointer(pointer | kPointerReferenceCounted); + ABSL_ASSERT(IsReferenceCounted()); } // Counterpart to `Metadata::SetArenaAllocated()` and @@ -525,6 +533,7 @@ struct AnyData final { 2); // Assert pointer alignment results in at least the 2 least // significant bits being unset. set_pointer(pointer | kPointerArenaAllocated); + ABSL_ASSERT(IsArenaAllocated()); } template From 7105bce144302f952db9e88b76b4fb74a74a6014 Mon Sep 17 00:00:00 2001 From: jcking Date: Fri, 5 May 2023 18:37:31 +0000 Subject: [PATCH 244/303] Deduplicate dispatching logic for `StructType` and `StructValue` PiperOrigin-RevId: 529772225 --- base/types/struct_type.cc | 44 +++++++++++------------------- base/values/struct_value.cc | 54 +++++++++++-------------------------- 2 files changed, 31 insertions(+), 67 deletions(-) diff --git a/base/types/struct_type.cc b/base/types/struct_type.cc index eee785d96..874139518 100644 --- a/base/types/struct_type.cc +++ b/base/types/struct_type.cc @@ -15,63 +15,49 @@ #include "base/types/struct_type.h" #include -#include #include "absl/base/macros.h" -#include "absl/hash/hash.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/variant.h" -#include "base/internal/message_wrapper.h" namespace cel { CEL_INTERNAL_TYPE_IMPL(StructType); +#define CEL_INTERNAL_STRUCT_TYPE_DISPATCH(method, ...) \ + base_internal::Metadata::IsStoredInline(*this) \ + ? static_cast(*this).method( \ + __VA_ARGS__) \ + : static_cast(*this).method( \ + __VA_ARGS__) + absl::string_view StructType::name() const { - if (base_internal::Metadata::IsStoredInline(*this)) { - return static_cast(this)->name(); - } - return static_cast(this)->name(); + return CEL_INTERNAL_STRUCT_TYPE_DISPATCH(name); } std::string StructType::DebugString() const { - if (base_internal::Metadata::IsStoredInline(*this)) { - return static_cast(this) - ->DebugString(); - } - return static_cast(this) - ->DebugString(); + return CEL_INTERNAL_STRUCT_TYPE_DISPATCH(DebugString); } internal::TypeInfo StructType::TypeId() const { - if (base_internal::Metadata::IsStoredInline(*this)) { - return static_cast(this)->TypeId(); - } - return static_cast(this)->TypeId(); + return CEL_INTERNAL_STRUCT_TYPE_DISPATCH(TypeId); } absl::StatusOr> StructType::FindFieldByName( TypeManager& type_manager, absl::string_view name) const { - if (base_internal::Metadata::IsStoredInline(*this)) { - return static_cast(this) - ->FindFieldByName(type_manager, name); - } - return static_cast(this) - ->FindFieldByName(type_manager, name); + return CEL_INTERNAL_STRUCT_TYPE_DISPATCH(FindFieldByName, type_manager, name); } absl::StatusOr> StructType::FindFieldByNumber( TypeManager& type_manager, int64_t number) const { - if (base_internal::Metadata::IsStoredInline(*this)) { - return static_cast(this) - ->FindFieldByNumber(type_manager, number); - } - return static_cast(this) - ->FindFieldByNumber(type_manager, number); + return CEL_INTERNAL_STRUCT_TYPE_DISPATCH(FindFieldByNumber, type_manager, + number); } +#undef CEL_INTERNAL_STRUCT_TYPE_DISPATCH + struct StructType::FindFieldVisitor final { const StructType& struct_type; TypeManager& type_manager; diff --git a/base/values/struct_value.cc b/base/values/struct_value.cc index a669043b3..ba6b16a4f 100644 --- a/base/values/struct_value.cc +++ b/base/values/struct_value.cc @@ -33,69 +33,47 @@ namespace cel { CEL_INTERNAL_VALUE_IMPL(StructValue); +#define CEL_INTERNAL_STRUCT_VALUE_DISPATCH(method, ...) \ + base_internal::Metadata::IsStoredInline(*this) \ + ? static_cast(*this).method( \ + __VA_ARGS__) \ + : static_cast(*this).method( \ + __VA_ARGS__) + Handle StructValue::type() const { - if (base_internal::Metadata::IsStoredInline(*this)) { - return static_cast(this)->type(); - } - return static_cast(this)->type(); + return CEL_INTERNAL_STRUCT_VALUE_DISPATCH(type); } std::string StructValue::DebugString() const { - if (base_internal::Metadata::IsStoredInline(*this)) { - return static_cast(this) - ->DebugString(); - } - return static_cast(this) - ->DebugString(); + return CEL_INTERNAL_STRUCT_VALUE_DISPATCH(DebugString); } absl::StatusOr> StructValue::GetFieldByName( const GetFieldContext& context, absl::string_view name) const { - if (base_internal::Metadata::IsStoredInline(*this)) { - return static_cast(this) - ->GetFieldByName(context, name); - } - return static_cast(this) - ->GetFieldByName(context, name); + return CEL_INTERNAL_STRUCT_VALUE_DISPATCH(GetFieldByName, context, name); } absl::StatusOr> StructValue::GetFieldByNumber( const GetFieldContext& context, int64_t number) const { - if (base_internal::Metadata::IsStoredInline(*this)) { - return static_cast(this) - ->GetFieldByNumber(context, number); - } - return static_cast(this) - ->GetFieldByNumber(context, number); + return CEL_INTERNAL_STRUCT_VALUE_DISPATCH(GetFieldByNumber, context, number); } absl::StatusOr StructValue::HasFieldByName(const HasFieldContext& context, absl::string_view name) const { - if (base_internal::Metadata::IsStoredInline(*this)) { - return static_cast(this) - ->HasFieldByName(context, name); - } - return static_cast(this) - ->HasFieldByName(context, name); + return CEL_INTERNAL_STRUCT_VALUE_DISPATCH(HasFieldByName, context, name); } absl::StatusOr StructValue::HasFieldByNumber( const HasFieldContext& context, int64_t number) const { - if (base_internal::Metadata::IsStoredInline(*this)) { - return static_cast(this) - ->HasFieldByNumber(context, number); - } - return static_cast(this) - ->HasFieldByNumber(context, number); + return CEL_INTERNAL_STRUCT_VALUE_DISPATCH(HasFieldByNumber, context, number); } internal::TypeInfo StructValue::TypeId() const { - if (base_internal::Metadata::IsStoredInline(*this)) { - return static_cast(this)->TypeId(); - } - return static_cast(this)->TypeId(); + return CEL_INTERNAL_STRUCT_VALUE_DISPATCH(TypeId); } +#undef CEL_INTERNAL_STRUCT_VALUE_DISPATCH + struct StructValue::GetFieldVisitor final { const StructValue& struct_value; const GetFieldContext& context; From 3216dd6da130c8f888508381fac053ae30638b87 Mon Sep 17 00:00:00 2001 From: jcking Date: Fri, 5 May 2023 21:31:58 +0000 Subject: [PATCH 245/303] Make `GetXByName` and `GetXByNumber` public PiperOrigin-RevId: 529815909 --- base/types/enum_type.h | 6 +++--- base/types/struct_type.h | 1 - 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/base/types/enum_type.h b/base/types/enum_type.h index 234f2e1c0..7c81f61f6 100644 --- a/base/types/enum_type.h +++ b/base/types/enum_type.h @@ -85,9 +85,6 @@ class EnumType : public Type, public base_internal::HeapData { // returned. absl::StatusOr> FindConstant(ConstantId id) const; - protected: - EnumType(); - // Called by FindConstant. virtual absl::StatusOr> FindConstantByName( absl::string_view name) const = 0; @@ -96,6 +93,9 @@ class EnumType : public Type, public base_internal::HeapData { virtual absl::StatusOr> FindConstantByNumber( int64_t number) const = 0; + protected: + EnumType(); + private: friend internal::TypeInfo base_internal::GetEnumTypeTypeId( const EnumType& enum_type); diff --git a/base/types/struct_type.h b/base/types/struct_type.h index 727800c63..f76fd22ad 100644 --- a/base/types/struct_type.h +++ b/base/types/struct_type.h @@ -92,7 +92,6 @@ class StructType : public Type { absl::StatusOr> FindField(TypeManager& type_manager, FieldId id) const; - protected: // Called by FindField. absl::StatusOr> FindFieldByName( TypeManager& type_manager, absl::string_view name) const; From 63cfcecd1438348b81f186900d865cced99f66c1 Mon Sep 17 00:00:00 2001 From: tswadell Date: Mon, 8 May 2023 16:00:25 +0000 Subject: [PATCH 246/303] Internal tool change PiperOrigin-RevId: 530316955 --- eval/testutil/test_message.proto | 1 + 1 file changed, 1 insertion(+) diff --git a/eval/testutil/test_message.proto b/eval/testutil/test_message.proto index 1501236f5..bd42dfb6e 100644 --- a/eval/testutil/test_message.proto +++ b/eval/testutil/test_message.proto @@ -72,6 +72,7 @@ message TestMessage { google.protobuf.Duration duration_value = 301; google.protobuf.Timestamp timestamp_value = 302; google.protobuf.Struct struct_value = 303; + // TODO(issues/5): Test null_value with variable bindings. google.protobuf.Value value_value = 304; google.protobuf.Int64Value int64_wrapper_value = 305; google.protobuf.Int32Value int32_wrapper_value = 306; From 2ba28a83be6a41647fa779edae2be6093589742a Mon Sep 17 00:00:00 2001 From: tswadell Date: Mon, 8 May 2023 16:14:34 +0000 Subject: [PATCH 247/303] Fix typo in comment PiperOrigin-RevId: 530320713 --- eval/public/cel_value.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eval/public/cel_value.h b/eval/public/cel_value.h index a88a9d7fe..57fff2ca9 100644 --- a/eval/public/cel_value.h +++ b/eval/public/cel_value.h @@ -52,7 +52,7 @@ namespace google::api::expr::runtime { using CelError = absl::Status; -// Break cyclic depdendencies for container types. +// Break cyclic dependencies for container types. class CelList; class CelMap; class LegacyTypeAdapter; From 1f181f64492ad39dc387fa0d7ab990b9bdae3100 Mon Sep 17 00:00:00 2001 From: jcking Date: Mon, 8 May 2023 16:33:45 +0000 Subject: [PATCH 248/303] Add ability to iterate over enumeration values PiperOrigin-RevId: 530325851 --- base/BUILD | 1 - base/type_test.cc | 9 ++- base/types/enum_type.cc | 13 +++- base/types/enum_type.h | 64 ++++++++++++------ base/value_test.cc | 8 +++ extensions/protobuf/BUILD | 1 + extensions/protobuf/enum_type.cc | 37 +++++++++++ extensions/protobuf/enum_type.h | 14 ++-- extensions/protobuf/enum_type_test.cc | 93 ++++++++++++++++++++++++--- 9 files changed, 204 insertions(+), 36 deletions(-) diff --git a/base/BUILD b/base/BUILD index e493b09a6..355d26817 100644 --- a/base/BUILD +++ b/base/BUILD @@ -177,7 +177,6 @@ cc_library( ":kind", ":memory_manager", "//base/internal:data", - "//base/internal:message_wrapper", "//base/internal:type", "//internal:casts", "//internal:no_destructor", diff --git a/base/type_test.cc b/base/type_test.cc index 2fc103c05..ac3b7a10a 100644 --- a/base/type_test.cc +++ b/base/type_test.cc @@ -36,7 +36,6 @@ namespace { using testing::ElementsAre; using testing::Eq; using cel::internal::IsOkAndHolds; -using cel::internal::StatusIs; enum class TestEnum { kValue1 = 1, @@ -49,6 +48,14 @@ class TestEnumType final : public EnumType { absl::string_view name() const override { return "test_enum.TestEnum"; } + size_t constant_count() const override { return 2; } + + absl::StatusOr> NewConstantIterator( + MemoryManager& memory_manager) const override { + return absl::UnimplementedError( + "EnumType::NewConstantIterator is unimplemented"); + } + protected: absl::StatusOr> FindConstantByName( absl::string_view name) const override { diff --git a/base/types/enum_type.cc b/base/types/enum_type.cc index d412855a6..0e3a526d0 100644 --- a/base/types/enum_type.cc +++ b/base/types/enum_type.cc @@ -14,12 +14,11 @@ #include "base/types/enum_type.h" -#include - #include "absl/base/macros.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/variant.h" +#include "internal/status_macros.h" namespace cel { @@ -50,4 +49,14 @@ absl::StatusOr> EnumType::FindConstant( return absl::visit(FindConstantVisitor{*this}, id.data_); } +absl::StatusOr EnumType::ConstantIterator::NextName() { + CEL_ASSIGN_OR_RETURN(auto constant, Next()); + return constant.name; +} + +absl::StatusOr EnumType::ConstantIterator::NextNumber() { + CEL_ASSIGN_OR_RETURN(auto constant, Next()); + return constant.number; +} + } // namespace cel diff --git a/base/types/enum_type.h b/base/types/enum_type.h index 7c81f61f6..e74cba3e0 100644 --- a/base/types/enum_type.h +++ b/base/types/enum_type.h @@ -19,6 +19,7 @@ #include #include +#include "absl/base/attributes.h" #include "absl/log/absl_check.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" @@ -26,6 +27,7 @@ #include "absl/types/variant.h" #include "base/internal/data.h" #include "base/kind.h" +#include "base/memory.h" #include "base/type.h" #include "internal/rtti.h" @@ -75,23 +77,33 @@ class EnumType : public Type, public base_internal::HeapData { Kind kind() const { return kKind; } - virtual absl::string_view name() const = 0; + virtual absl::string_view name() const ABSL_ATTRIBUTE_LIFETIME_BOUND = 0; std::string DebugString() const { return std::string(name()); } + virtual size_t constant_count() const = 0; + // Find the constant definition for the given identifier. If the constant does // not exist, an OK status and empty optional is returned. If the constant // exists, an OK status and the constant is returned. Otherwise an error is // returned. - absl::StatusOr> FindConstant(ConstantId id) const; + absl::StatusOr> FindConstant(ConstantId id) const + ABSL_ATTRIBUTE_LIFETIME_BOUND; // Called by FindConstant. virtual absl::StatusOr> FindConstantByName( - absl::string_view name) const = 0; + absl::string_view name) const ABSL_ATTRIBUTE_LIFETIME_BOUND = 0; // Called by FindConstant. virtual absl::StatusOr> FindConstantByNumber( - int64_t number) const = 0; + int64_t number) const ABSL_ATTRIBUTE_LIFETIME_BOUND = 0; + + class ConstantIterator; + + // Returns an iterator which can iterate over all the constants defined by + // this enumeration. The order with which iteration occurs is undefined. + virtual absl::StatusOr> NewConstantIterator( + MemoryManager& memory_manager) const ABSL_ATTRIBUTE_LIFETIME_BOUND = 0; protected: EnumType(); @@ -116,6 +128,36 @@ class EnumType : public Type, public base_internal::HeapData { virtual internal::TypeInfo TypeId() const = 0; }; +// Constant describes a single value in an enumeration. All fields are valid so +// long as EnumType is valid. +struct EnumType::Constant final { + Constant(absl::string_view name, int64_t number, const void* hint = nullptr) + : name(name), number(number), hint(hint) {} + + // The unqualified enumeration value name. + absl::string_view name; + // The enumeration value number. + int64_t number; + // Some implementation-specific data that can be laundered to the value + // implementation for this type to perform optimizations. + const void* hint = nullptr; +}; + +class EnumType::ConstantIterator { + public: + using Constant = EnumType::Constant; + + virtual ~ConstantIterator() = default; + + ABSL_MUST_USE_RESULT virtual bool HasNext() = 0; + + virtual absl::StatusOr Next() = 0; + + virtual absl::StatusOr NextName(); + + virtual absl::StatusOr NextNumber(); +}; + // CEL_DECLARE_ENUM_TYPE declares `enum_type` as an enumeration type. It must be // part of the class definition of `enum_type`. // @@ -140,20 +182,6 @@ class EnumType : public Type, public base_internal::HeapData { #define CEL_IMPLEMENT_ENUM_TYPE(enum_type) \ CEL_INTERNAL_IMPLEMENT_TYPE(Enum, enum_type) -struct EnumType::Constant final { - explicit Constant(absl::string_view name, int64_t number, - const void* hint = nullptr) - : name(name), number(number) {} - - // The unqualified enumeration value name. - absl::string_view name; - // The enumeration value number. - int64_t number; - // Some implementation-specific data that can be laundered to the value - // implementation for this type to perform optimizations. - const void* hint = nullptr; -}; - CEL_INTERNAL_TYPE_DECL(EnumType); namespace base_internal { diff --git a/base/value_test.cc b/base/value_test.cc index a1a0f1065..fcf982e6d 100644 --- a/base/value_test.cc +++ b/base/value_test.cc @@ -60,6 +60,14 @@ class TestEnumType final : public EnumType { absl::string_view name() const override { return "test_enum.TestEnum"; } + size_t constant_count() const override { return 2; } + + absl::StatusOr> NewConstantIterator( + MemoryManager& memory_manager) const override { + return absl::UnimplementedError( + "EnumType::NewConstantIterator is unimplemented"); + } + protected: absl::StatusOr> FindConstantByName( absl::string_view name) const override { diff --git a/extensions/protobuf/BUILD b/extensions/protobuf/BUILD index 682581945..9cfcc7d87 100644 --- a/extensions/protobuf/BUILD +++ b/extensions/protobuf/BUILD @@ -125,6 +125,7 @@ cc_test( "//base/testing:type_matchers", "//extensions/protobuf/internal:testing", "//internal:testing", + "@com_google_absl//absl/status", "@com_google_cel_spec//proto/test/v1/proto3:test_all_types_cc_proto", "@com_google_protobuf//:protobuf", ], diff --git a/extensions/protobuf/enum_type.cc b/extensions/protobuf/enum_type.cc index 11e0b8e39..56de20399 100644 --- a/extensions/protobuf/enum_type.cc +++ b/extensions/protobuf/enum_type.cc @@ -25,6 +25,33 @@ namespace cel::extensions { +namespace { + +class ProtoEnumTypeConstantIterator final : public EnumType::ConstantIterator { + public: + explicit ProtoEnumTypeConstantIterator( + const google::protobuf::EnumDescriptor& descriptor) + : descriptor_(descriptor) {} + + bool HasNext() override { return index_ < descriptor_.value_count(); } + + absl::StatusOr Next() override { + if (ABSL_PREDICT_FALSE(index_ >= descriptor_.value_count())) { + return absl::FailedPreconditionError( + "EnumType::ConstantIterator::Next() called when " + "EnumType::ConstantIterator::HasNext() returns false"); + } + const auto* value = descriptor_.value(index_++); + return Constant(value->name(), value->number(), value); + } + + private: + const google::protobuf::EnumDescriptor& descriptor_; + int index_ = 0; +}; + +} // namespace + absl::StatusOr> ProtoEnumType::Resolve( TypeManager& type_manager, const google::protobuf::EnumDescriptor& descriptor) { CEL_ASSIGN_OR_RETURN(auto type, @@ -42,6 +69,10 @@ absl::StatusOr> ProtoEnumType::Resolve( return std::move(type).value().As(); } +size_t ProtoEnumType::constant_count() const { + return descriptor().value_count(); +} + absl::StatusOr> ProtoEnumType::FindConstantByName(absl::string_view name) const { const auto* value_desc = descriptor().FindValueByName(name); @@ -68,4 +99,10 @@ ProtoEnumType::FindConstantByNumber(int64_t number) const { return Constant{value_desc->name(), value_desc->number(), value_desc}; } +absl::StatusOr> +ProtoEnumType::NewConstantIterator(MemoryManager& memory_manager) const { + return MakeUnique(memory_manager, + descriptor()); +} + } // namespace cel::extensions diff --git a/extensions/protobuf/enum_type.h b/extensions/protobuf/enum_type.h index 768e9b417..47abb2781 100644 --- a/extensions/protobuf/enum_type.h +++ b/extensions/protobuf/enum_type.h @@ -32,9 +32,9 @@ class ProtoTypeProvider; class ProtoEnumType final : public EnumType { public: static bool Is(const Type& type) { - return type.kind() == Kind::kEnum && - cel::base_internal::GetEnumTypeTypeId(static_cast( - type)) == cel::internal::TypeId(); + return EnumType::Is(type) && cel::base_internal::GetEnumTypeTypeId( + static_cast(type)) == + cel::internal::TypeId(); } using EnumType::Is; @@ -46,9 +46,8 @@ class ProtoEnumType final : public EnumType { absl::string_view name() const override { return descriptor().full_name(); } - const google::protobuf::EnumDescriptor& descriptor() const { return *descriptor_; } + size_t constant_count() const override; - protected: // Called by FindField. absl::StatusOr> FindConstantByName( absl::string_view name) const override; @@ -57,6 +56,11 @@ class ProtoEnumType final : public EnumType { absl::StatusOr> FindConstantByNumber( int64_t number) const override; + absl::StatusOr> NewConstantIterator( + MemoryManager& memory_manager) const override; + + const google::protobuf::EnumDescriptor& descriptor() const { return *descriptor_; } + private: friend class ProtoType; friend class ProtoTypeProvider; diff --git a/extensions/protobuf/enum_type_test.cc b/extensions/protobuf/enum_type_test.cc index 1935ce321..eee98dcc4 100644 --- a/extensions/protobuf/enum_type_test.cc +++ b/extensions/protobuf/enum_type_test.cc @@ -14,11 +14,16 @@ #include "extensions/protobuf/enum_type.h" +#include + #include "google/protobuf/type.pb.h" +#include "absl/status/status.h" +#include "base/internal/memory_manager_testing.h" #include "base/kind.h" -#include "base/memory_manager.h" +#include "base/memory.h" #include "base/type_factory.h" #include "base/type_manager.h" +#include "extensions/protobuf/internal/testing.h" #include "extensions/protobuf/type.h" #include "extensions/protobuf/type_provider.h" #include "internal/testing.h" @@ -27,8 +32,12 @@ namespace cel::extensions { namespace { -TEST(ProtoEnumType, CreateStatically) { - TypeFactory type_factory(MemoryManager::Global()); +using cel::internal::StatusIs; + +using ProtoEnumTypeTest = ProtoTest<>; + +TEST_P(ProtoEnumTypeTest, CreateStatically) { + TypeFactory type_factory(memory_manager()); ProtoTypeProvider type_provider; TypeManager type_manager(type_factory, type_provider); ASSERT_OK_AND_ASSIGN( @@ -42,8 +51,8 @@ TEST(ProtoEnumType, CreateStatically) { google::protobuf::GetEnumDescriptor()); } -TEST(ProtoEnumType, CreateDynamically) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(ProtoEnumTypeTest, CreateDynamically) { + TypeFactory type_factory(memory_manager()); ProtoTypeProvider type_provider; TypeManager type_manager(type_factory, type_provider); ASSERT_OK_AND_ASSIGN( @@ -59,8 +68,8 @@ TEST(ProtoEnumType, CreateDynamically) { google::protobuf::GetEnumDescriptor()); } -TEST(ProtoEnumType, FindConstantByName) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(ProtoEnumTypeTest, FindConstantByName) { + TypeFactory type_factory(memory_manager()); ProtoTypeProvider type_provider; TypeManager type_manager(type_factory, type_provider); ASSERT_OK_AND_ASSIGN( @@ -73,8 +82,8 @@ TEST(ProtoEnumType, FindConstantByName) { EXPECT_EQ(constant->name, "TYPE_STRING"); } -TEST(ProtoEnumType, FindConstantByNumber) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(ProtoEnumTypeTest, FindConstantByNumber) { + TypeFactory type_factory(memory_manager()); ProtoTypeProvider type_provider; TypeManager type_manager(type_factory, type_provider); ASSERT_OK_AND_ASSIGN( @@ -87,5 +96,71 @@ TEST(ProtoEnumType, FindConstantByNumber) { EXPECT_EQ(constant->name, "TYPE_STRING"); } +TEST_P(ProtoEnumTypeTest, ConstantCount) { + TypeFactory type_factory(memory_manager()); + ProtoTypeProvider type_provider; + TypeManager type_manager(type_factory, type_provider); + ASSERT_OK_AND_ASSIGN( + auto type, + ProtoType::Resolve(type_manager)); + EXPECT_EQ(type->constant_count(), + google::protobuf::GetEnumDescriptor() + ->value_count()); +} + +TEST_P(ProtoEnumTypeTest, NewConstantIteratorNames) { + TypeFactory type_factory(memory_manager()); + ProtoTypeProvider type_provider; + TypeManager type_manager(type_factory, type_provider); + ASSERT_OK_AND_ASSIGN( + auto type, + ProtoType::Resolve(type_manager)); + ASSERT_OK_AND_ASSIGN(auto iterator, + type->NewConstantIterator(memory_manager())); + std::set actual_names; + while (iterator->HasNext()) { + ASSERT_OK_AND_ASSIGN(auto name, iterator->NextName()); + actual_names.insert(name); + } + EXPECT_THAT(iterator->Next(), + StatusIs(absl::StatusCode::kFailedPrecondition)); + std::set expected_names; + const auto* const descriptor = + google::protobuf::GetEnumDescriptor(); + for (int index = 0; index < descriptor->value_count(); index++) { + expected_names.insert(descriptor->value(index)->name()); + } + EXPECT_EQ(actual_names, expected_names); +} + +TEST_P(ProtoEnumTypeTest, NewConstantIteratorNumbers) { + TypeFactory type_factory(memory_manager()); + ProtoTypeProvider type_provider; + TypeManager type_manager(type_factory, type_provider); + ASSERT_OK_AND_ASSIGN( + auto type, + ProtoType::Resolve(type_manager)); + ASSERT_OK_AND_ASSIGN(auto iterator, + type->NewConstantIterator(memory_manager())); + std::set actual_names; + while (iterator->HasNext()) { + ASSERT_OK_AND_ASSIGN(auto number, iterator->NextNumber()); + actual_names.insert(number); + } + EXPECT_THAT(iterator->Next(), + StatusIs(absl::StatusCode::kFailedPrecondition)); + std::set expected_names; + const auto* const descriptor = + google::protobuf::GetEnumDescriptor(); + for (int index = 0; index < descriptor->value_count(); index++) { + expected_names.insert(descriptor->value(index)->number()); + } + EXPECT_EQ(actual_names, expected_names); +} + +INSTANTIATE_TEST_SUITE_P(ProtoEnumTypeTest, ProtoEnumTypeTest, + cel::base_internal::MemoryManagerTestModeAll(), + cel::base_internal::MemoryManagerTestModeTupleName); + } // namespace } // namespace cel::extensions From 0b5a3085ba26a0daa414281145ef07ae9d14077e Mon Sep 17 00:00:00 2001 From: jdtatum Date: Mon, 8 May 2023 17:27:29 +0000 Subject: [PATCH 249/303] Add expression builder benchmark for precompiling regular expressions. PiperOrigin-RevId: 530340938 --- eval/tests/BUILD | 1 + .../expression_builder_benchmark_test.cc | 50 +++++++++++++++++++ 2 files changed, 51 insertions(+) diff --git a/eval/tests/BUILD b/eval/tests/BUILD index 97c72ec9a..beff303b5 100644 --- a/eval/tests/BUILD +++ b/eval/tests/BUILD @@ -142,6 +142,7 @@ cc_test( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:node_hash_set", "@com_google_absl//absl/strings", + "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], diff --git a/eval/tests/expression_builder_benchmark_test.cc b/eval/tests/expression_builder_benchmark_test.cc index 64f9ad693..4317b0e49 100644 --- a/eval/tests/expression_builder_benchmark_test.cc +++ b/eval/tests/expression_builder_benchmark_test.cc @@ -17,6 +17,7 @@ #include #include +#include "google/api/expr/v1alpha1/checked.pb.h" #include "google/api/expr/v1alpha1/syntax.pb.h" #include "google/protobuf/text_format.h" #include "absl/container/flat_hash_set.h" @@ -162,6 +163,55 @@ BENCHMARK(BM_Comparisons) ->Arg(BenchmarkParam::kFoldConstants) ->Arg(BenchmarkParam::kUpdatedFoldConstants); +void RegexPrecompilationBench(bool enabled, benchmark::State& state) { + auto param = static_cast(state.range(0)); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(R"cel( + input_str.matches(r'192\.168\.' + '[0-9]{1,3}' + r'\.' + '[0-9]{1,3}') || + input_str.matches(r'10(\.[0-9]{1,3}){3}') + )cel")); + + // Fake a checked expression with enough reference information for the expr + // builder to identify the regex as optimize-able. + CheckedExpr checked_expr; + checked_expr.mutable_expr()->Swap(expr.mutable_expr()); + checked_expr.mutable_source_info()->Swap(expr.mutable_source_info()); + (*checked_expr.mutable_reference_map())[2].add_overload_id("matches_string"); + (*checked_expr.mutable_reference_map())[11].add_overload_id("matches_string"); + + google::protobuf::Arena arena; + InterpreterOptions options = OptionsForParam(param, arena); + options.enable_regex_precompilation = enabled; + + auto builder = CreateCelExpressionBuilder(options); + auto reg_status = RegisterBuiltinFunctions(builder->GetRegistry()); + ASSERT_OK(reg_status); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN(auto expression, + builder->CreateExpression(&checked_expr)); + arena.Reset(); + } +} + +void BM_RegexPrecompilationDisabled(benchmark::State& state) { + RegexPrecompilationBench(false, state); +} + +BENCHMARK(BM_RegexPrecompilationDisabled) + ->Arg(BenchmarkParam::kDefault) + ->Arg(BenchmarkParam::kFoldConstants) + ->Arg(BenchmarkParam::kUpdatedFoldConstants); + +void BM_RegexPrecompilationEnabled(benchmark::State& state) { + RegexPrecompilationBench(true, state); +} + +BENCHMARK(BM_RegexPrecompilationEnabled) + ->Arg(BenchmarkParam::kDefault) + ->Arg(BenchmarkParam::kFoldConstants) + ->Arg(BenchmarkParam::kUpdatedFoldConstants); + void BM_StringConcat(benchmark::State& state) { auto param = static_cast(state.range(0)); auto size = state.range(1); From baee8df1ffd771cb9726c397e198e9ee5bab38ce Mon Sep 17 00:00:00 2001 From: jcking Date: Tue, 9 May 2023 15:01:49 +0000 Subject: [PATCH 250/303] Add ability to iterate over containers PiperOrigin-RevId: 530610320 --- base/value_test.cc | 139 +++++++++++++++++++++++++++- base/values/list_value.cc | 100 +++++++++++++++++++++ base/values/list_value.h | 42 ++++++++- base/values/map_value.cc | 165 ++++++++++++++++++++++++++++++++++ base/values/map_value.h | 41 +++++++++ eval/internal/interop_test.cc | 117 +++++++++++++++++++++++- 6 files changed, 600 insertions(+), 4 deletions(-) diff --git a/base/value_test.cc b/base/value_test.cc index fcf982e6d..569daac91 100644 --- a/base/value_test.cc +++ b/base/value_test.cc @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -30,7 +31,7 @@ #include "absl/strings/str_join.h" #include "absl/time/time.h" #include "base/internal/memory_manager_testing.h" -#include "base/memory_manager.h" +#include "base/memory.h" #include "base/type.h" #include "base/type_factory.h" #include "base/type_manager.h" @@ -280,6 +281,36 @@ class TestListValue final : public CEL_LIST_VALUE_CLASS { CEL_IMPLEMENT_LIST_VALUE(TestListValue); +class TestMapKeysListValue final : public CEL_LIST_VALUE_CLASS { + public: + explicit TestMapKeysListValue(const Handle& type, + std::vector elements) + : CEL_LIST_VALUE_CLASS(type), elements_(std::move(elements)) {} + + size_t size() const override { return elements_.size(); } + + absl::StatusOr> Get(const GetContext& context, + size_t index) const override { + if (index >= size()) { + return absl::OutOfRangeError(""); + } + return context.value_factory().CreateStringValue(elements_[index]); + } + + std::string DebugString() const override { + return absl::StrCat("[", absl::StrJoin(elements_, ", "), "]"); + } + + const std::vector& value() const { return elements_; } + + private: + std::vector elements_; + + CEL_DECLARE_LIST_VALUE(TestMapKeysListValue); +}; + +CEL_IMPLEMENT_LIST_VALUE(TestMapKeysListValue); + class TestMapValue final : public CEL_MAP_VALUE_CLASS { public: explicit TestMapValue(const Handle& type, @@ -326,7 +357,17 @@ class TestMapValue final : public CEL_MAP_VALUE_CLASS { absl::StatusOr> ListKeys( const ListKeysContext& context) const override { - return absl::UnimplementedError("MapValue::ListKeys is not implemented"); + CEL_ASSIGN_OR_RETURN( + auto list_type, + context.value_factory().type_factory().CreateListType( + context.value_factory().type_factory().GetStringType())); + std::vector keys; + keys.reserve(entries_.size()); + for (const auto& entry : entries_) { + keys.push_back(entry.first); + } + return context.value_factory().CreateListValue( + std::move(list_type), std::move(keys)); } const std::map& value() const { return entries_; } @@ -2192,6 +2233,52 @@ TEST_P(ListValueTest, Get) { StatusIs(absl::StatusCode::kOutOfRange)); } +TEST_P(ListValueTest, NewIteratorIndices) { + TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); + ASSERT_OK_AND_ASSIGN(auto list_type, + type_factory.CreateListType(type_factory.GetIntType())); + ASSERT_OK_AND_ASSIGN(auto list_value, + value_factory.CreateListValue( + list_type, std::vector{0, 1, 2})); + ASSERT_OK_AND_ASSIGN(auto iterator, + list_value->NewIterator(memory_manager())); + std::set actual_indices; + while (iterator->HasNext()) { + ASSERT_OK_AND_ASSIGN( + auto index, iterator->NextIndex(ListValue::GetContext(value_factory))); + actual_indices.insert(index); + } + EXPECT_THAT(iterator->NextIndex(ListValue::GetContext(value_factory)), + StatusIs(absl::StatusCode::kFailedPrecondition)); + std::set expected_indices = {0, 1, 2}; + EXPECT_EQ(actual_indices, expected_indices); +} + +TEST_P(ListValueTest, NewIteratorValues) { + TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); + ASSERT_OK_AND_ASSIGN(auto list_type, + type_factory.CreateListType(type_factory.GetIntType())); + ASSERT_OK_AND_ASSIGN(auto list_value, + value_factory.CreateListValue( + list_type, std::vector{3, 4, 5})); + ASSERT_OK_AND_ASSIGN(auto iterator, + list_value->NewIterator(memory_manager())); + std::set actual_values; + while (iterator->HasNext()) { + ASSERT_OK_AND_ASSIGN( + auto value, iterator->NextValue(ListValue::GetContext(value_factory))); + actual_values.insert(value->As().value()); + } + EXPECT_THAT(iterator->NextValue(ListValue::GetContext(value_factory)), + StatusIs(absl::StatusCode::kFailedPrecondition)); + std::set expected_values = {3, 4, 5}; + EXPECT_EQ(actual_values, expected_values); +} + INSTANTIATE_TEST_SUITE_P(ListValueTest, ListValueTest, base_internal::MemoryManagerTestModeAll(), base_internal::MemoryManagerTestModeTupleName); @@ -2300,6 +2387,54 @@ TEST_P(MapValueTest, GetAndHas) { IsOkAndHolds(false)); } +TEST_P(MapValueTest, NewIteratorKeys) { + TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); + ASSERT_OK_AND_ASSIGN(auto map_type, + type_factory.CreateMapType(type_factory.GetStringType(), + type_factory.GetIntType())); + ASSERT_OK_AND_ASSIGN(auto map_value, + value_factory.CreateMapValue( + map_type, std::map{ + {"foo", 1}, {"bar", 2}, {"baz", 3}})); + ASSERT_OK_AND_ASSIGN(auto iterator, map_value->NewIterator(memory_manager())); + std::set actual_keys; + while (iterator->HasNext()) { + ASSERT_OK_AND_ASSIGN( + auto key, iterator->NextKey(MapValue::GetContext(value_factory))); + actual_keys.insert(key->As().ToString()); + } + EXPECT_THAT(iterator->NextKey(MapValue::GetContext(value_factory)), + StatusIs(absl::StatusCode::kFailedPrecondition)); + std::set expected_keys = {"foo", "bar", "baz"}; + EXPECT_EQ(actual_keys, expected_keys); +} + +TEST_P(MapValueTest, NewIteratorValues) { + TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); + ASSERT_OK_AND_ASSIGN(auto map_type, + type_factory.CreateMapType(type_factory.GetStringType(), + type_factory.GetIntType())); + ASSERT_OK_AND_ASSIGN(auto map_value, + value_factory.CreateMapValue( + map_type, std::map{ + {"foo", 1}, {"bar", 2}, {"baz", 3}})); + ASSERT_OK_AND_ASSIGN(auto iterator, map_value->NewIterator(memory_manager())); + std::set actual_values; + while (iterator->HasNext()) { + ASSERT_OK_AND_ASSIGN( + auto value, iterator->NextValue(MapValue::GetContext(value_factory))); + actual_values.insert(value->As().value()); + } + EXPECT_THAT(iterator->NextValue(MapValue::GetContext(value_factory)), + StatusIs(absl::StatusCode::kFailedPrecondition)); + std::set expected_values = {1, 2, 3}; + EXPECT_EQ(actual_values, expected_values); +} + INSTANTIATE_TEST_SUITE_P(MapValueTest, MapValueTest, base_internal::MemoryManagerTestModeAll(), base_internal::MemoryManagerTestModeTupleName); diff --git a/base/values/list_value.cc b/base/values/list_value.cc index a13baa603..28f23a61b 100644 --- a/base/values/list_value.cc +++ b/base/values/list_value.cc @@ -19,12 +19,14 @@ #include #include "absl/base/macros.h" +#include "absl/base/optimization.h" #include "absl/status/statusor.h" #include "base/handle.h" #include "base/internal/data.h" #include "base/type.h" #include "base/types/list_type.h" #include "internal/rtti.h" +#include "internal/status_macros.h" namespace cel { @@ -58,14 +60,102 @@ absl::StatusOr> ListValue::Get(const GetContext& context, return CEL_INTERNAL_LIST_VALUE_DISPATCH(Get, context, index); } +absl::StatusOr> ListValue::NewIterator( + MemoryManager& memory_manager) const { + return CEL_INTERNAL_LIST_VALUE_DISPATCH(NewIterator, memory_manager); +} + internal::TypeInfo ListValue::TypeId() const { return CEL_INTERNAL_LIST_VALUE_DISPATCH(TypeId); } #undef CEL_INTERNAL_LIST_VALUE_DISPATCH +absl::StatusOr ListValue::Iterator::NextIndex( + const ListValue::GetContext& context) { + CEL_ASSIGN_OR_RETURN(auto element, Next(context)); + return element.index; +} + +absl::StatusOr> ListValue::Iterator::NextValue( + const ListValue::GetContext& context) { + CEL_ASSIGN_OR_RETURN(auto element, Next(context)); + return std::move(element.value); +} + namespace base_internal { +namespace { + +class LegacyListValueIterator final : public ListValue::Iterator { + public: + explicit LegacyListValueIterator(uintptr_t impl) + : impl_(impl), size_(LegacyListValueSize(impl_)) {} + + bool HasNext() override { return index_ < size_; } + + absl::StatusOr Next(const ListValue::GetContext& context) override { + if (ABSL_PREDICT_FALSE(index_ >= size_)) { + return absl::FailedPreconditionError( + "ListValue::Iterator::Next() called when " + "ListValue::Iterator::HasNext() returns false"); + } + CEL_ASSIGN_OR_RETURN( + auto value, LegacyListValueGet(impl_, context.value_factory(), index_)); + return Element(index_++, std::move(value)); + } + + absl::StatusOr NextIndex( + const ListValue::GetContext& context) override { + if (ABSL_PREDICT_FALSE(index_ >= size_)) { + return absl::FailedPreconditionError( + "ListValue::Iterator::Next() called when " + "ListValue::Iterator::HasNext() returns false"); + } + return index_++; + } + + private: + const uintptr_t impl_; + const size_t size_; + size_t index_ = 0; +}; + +class AbstractListValueIterator final : public ListValue::Iterator { + public: + explicit AbstractListValueIterator(const AbstractListValue* value) + : value_(value), size_(value_->size()) {} + + bool HasNext() override { return index_ < size_; } + + absl::StatusOr Next(const ListValue::GetContext& context) override { + if (ABSL_PREDICT_FALSE(index_ >= size_)) { + return absl::FailedPreconditionError( + "ListValue::Iterator::Next() called when " + "ListValue::Iterator::HasNext() returns false"); + } + CEL_ASSIGN_OR_RETURN(auto value, value_->Get(context, index_)); + return Element(index_++, std::move(value)); + } + + absl::StatusOr NextIndex( + const ListValue::GetContext& context) override { + if (ABSL_PREDICT_FALSE(index_ >= size_)) { + return absl::FailedPreconditionError( + "ListValue::Iterator::Next() called when " + "ListValue::Iterator::HasNext() returns false"); + } + return index_++; + } + + private: + const AbstractListValue* const value_; + const size_t size_; + size_t index_ = 0; +}; + +} // namespace + Handle LegacyListValue::type() const { return HandleFactory::Make(); } @@ -76,6 +166,11 @@ size_t LegacyListValue::size() const { return LegacyListValueSize(impl_); } bool LegacyListValue::empty() const { return LegacyListValueEmpty(impl_); } +absl::StatusOr> LegacyListValue::NewIterator( + MemoryManager& memory_manager) const { + return MakeUnique(memory_manager, impl_); +} + absl::StatusOr> LegacyListValue::Get(const GetContext& context, size_t index) const { return LegacyListValueGet(impl_, context.value_factory(), index); @@ -88,6 +183,11 @@ AbstractListValue::AbstractListValue(Handle type) reinterpret_cast(static_cast(this))); } +absl::StatusOr> AbstractListValue::NewIterator( + MemoryManager& memory_manager) const { + return MakeUnique(memory_manager, this); +} + } // namespace base_internal } // namespace cel diff --git a/base/values/list_value.h b/base/values/list_value.h index 90590957e..513757036 100644 --- a/base/values/list_value.h +++ b/base/values/list_value.h @@ -25,10 +25,10 @@ #include "absl/hash/hash.h" #include "absl/log/absl_check.h" #include "absl/status/statusor.h" -#include "base/allocator.h" #include "base/handle.h" #include "base/internal/data.h" #include "base/kind.h" +#include "base/memory.h" #include "base/owner.h" #include "base/type.h" #include "base/types/list_type.h" @@ -87,6 +87,19 @@ class ListValue : public Value { absl::StatusOr> Get(const GetContext& context, size_t index) const; + struct Element final { + Element(size_t index, Handle value) + : index(index), value(std::move(value)) {} + + size_t index; + Handle value; + }; + + class Iterator; + + absl::StatusOr> NewIterator( + MemoryManager& memory_manager) const ABSL_ATTRIBUTE_LIFETIME_BOUND; + bool Equals(const Value& other) const; void HashValue(absl::HashState state) const; @@ -104,6 +117,27 @@ class ListValue : public Value { internal::TypeInfo TypeId() const; }; +// Abstract class describes an iterator which can iterate over the elements in a +// list. A default implementation is provided by `ListValue::NewIterator`, +// however it is likely not as efficient as providing your own implementation. +class ListValue::Iterator { + public: + using Element = ListValue::Element; + + virtual ~Iterator() = default; + + ABSL_MUST_USE_RESULT virtual bool HasNext() = 0; + + virtual absl::StatusOr Next( + const ListValue::GetContext& context) = 0; + + virtual absl::StatusOr NextIndex( + const ListValue::GetContext& context); + + virtual absl::StatusOr> NextValue( + const ListValue::GetContext& context); +}; + namespace base_internal { ABSL_ATTRIBUTE_WEAK absl::StatusOr> LegacyListValueGet( @@ -139,6 +173,9 @@ class LegacyListValue final : public ListValue, public InlineData { constexpr uintptr_t value() const { return impl_; } + absl::StatusOr> NewIterator( + MemoryManager& memory_manager) const ABSL_ATTRIBUTE_LIFETIME_BOUND; + private: friend class base_internal::ValueHandle; friend class cel::ListValue; @@ -187,6 +224,9 @@ class AbstractListValue : public ListValue, virtual absl::StatusOr> Get(const GetContext& context, size_t index) const = 0; + virtual absl::StatusOr> NewIterator( + MemoryManager& memory_manager) const ABSL_ATTRIBUTE_LIFETIME_BOUND; + protected: explicit AbstractListValue(Handle type); diff --git a/base/values/map_value.cc b/base/values/map_value.cc index d2966a0d2..25a71b8a8 100644 --- a/base/values/map_value.cc +++ b/base/values/map_value.cc @@ -19,14 +19,19 @@ #include #include "absl/base/macros.h" +#include "absl/base/optimization.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/optional.h" #include "base/handle.h" #include "base/internal/data.h" #include "base/types/map_type.h" #include "base/value.h" +#include "base/value_factory.h" #include "base/values/list_value.h" #include "internal/rtti.h" +#include "internal/status_macros.h" namespace cel { @@ -66,14 +71,164 @@ absl::StatusOr> MapValue::ListKeys( return CEL_INTERNAL_MAP_VALUE_DISPATCH(ListKeys, context); } +absl::StatusOr> MapValue::NewIterator( + MemoryManager& memory_manager) const { + return CEL_INTERNAL_MAP_VALUE_DISPATCH(NewIterator, memory_manager); +} + internal::TypeInfo MapValue::TypeId() const { return CEL_INTERNAL_MAP_VALUE_DISPATCH(TypeId); } #undef CEL_INTERNAL_MAP_VALUE_DISPATCH +absl::StatusOr> MapValue::Iterator::NextKey( + const MapValue::GetContext& context) { + CEL_ASSIGN_OR_RETURN(auto entry, Next(context)); + return std::move(entry.key); +} + +absl::StatusOr> MapValue::Iterator::NextValue( + const MapValue::GetContext& context) { + CEL_ASSIGN_OR_RETURN(auto entry, Next(context)); + return std::move(entry.value); +} + namespace base_internal { +namespace { + +class LegacyMapValueIterator final : public MapValue::Iterator { + public: + explicit LegacyMapValueIterator(uintptr_t impl) : impl_(impl) {} + + bool HasNext() override { + if (ABSL_PREDICT_FALSE(!keys_iterator_.has_value())) { + // First call. + return !LegacyMapValueEmpty(impl_); + } + return (*keys_iterator_)->HasNext(); + } + + absl::StatusOr Next(const MapValue::GetContext& context) override { + CEL_RETURN_IF_ERROR(OnNext(context.value_factory())); + CEL_ASSIGN_OR_RETURN( + auto key, + (*keys_iterator_) + ->NextValue(ListValue::GetContext(context.value_factory()))); + CEL_ASSIGN_OR_RETURN( + auto value, LegacyMapValueGet(impl_, context.value_factory(), key)); + if (ABSL_PREDICT_FALSE(!value.has_value())) { + // Something is seriously wrong. The list of keys from the map is not + // consistent with what the map believes is set. + return absl::InternalError( + "inconsistency between list of map keys and map"); + } + return Entry(std::move(key), std::move(value).value()); + } + + absl::StatusOr> NextKey( + const MapValue::GetContext& context) override { + CEL_RETURN_IF_ERROR(OnNext(context.value_factory())); + CEL_ASSIGN_OR_RETURN( + auto key, + (*keys_iterator_) + ->NextValue(ListValue::GetContext(context.value_factory()))); + return key; + } + + absl::StatusOr> NextValue( + const MapValue::GetContext& context) override { + CEL_RETURN_IF_ERROR(OnNext(context.value_factory())); + CEL_ASSIGN_OR_RETURN( + auto key, + (*keys_iterator_) + ->NextValue(ListValue::GetContext(context.value_factory()))); + CEL_ASSIGN_OR_RETURN( + auto value, LegacyMapValueGet(impl_, context.value_factory(), key)); + if (ABSL_PREDICT_FALSE(!value.has_value())) { + // Something is seriously wrong. The list of keys from the map is not + // consistent with what the map believes is set. + return absl::InternalError( + "inconsistency between list of map keys and map"); + } + return std::move(value).value(); + } + + private: + absl::Status OnNext(ValueFactory& value_factory) { + if (ABSL_PREDICT_FALSE(!keys_iterator_.has_value())) { + CEL_ASSIGN_OR_RETURN(keys_, LegacyMapValueListKeys(impl_, value_factory)); + CEL_ASSIGN_OR_RETURN(keys_iterator_, + keys_->NewIterator(value_factory.memory_manager())); + ABSL_CHECK((*keys_iterator_)->HasNext()); // Crash OK + } + return absl::OkStatus(); + } + + const uintptr_t impl_; + Handle keys_; + absl::optional> keys_iterator_; +}; + +class AbstractMapValueIterator final : public MapValue::Iterator { + public: + explicit AbstractMapValueIterator(const AbstractMapValue* value) + : value_(value) {} + + bool HasNext() override { + if (ABSL_PREDICT_FALSE(!keys_iterator_.has_value())) { + // First call. + return !value_->empty(); + } + return (*keys_iterator_)->HasNext(); + } + + absl::StatusOr Next(const MapValue::GetContext& context) override { + CEL_RETURN_IF_ERROR(OnNext(context.value_factory())); + CEL_ASSIGN_OR_RETURN( + auto key, + (*keys_iterator_) + ->NextValue(ListValue::GetContext(context.value_factory()))); + CEL_ASSIGN_OR_RETURN(auto value, value_->Get(context, key)); + if (ABSL_PREDICT_FALSE(!value.has_value())) { + // Something is seriously wrong. The list of keys from the map is not + // consistent with what the map believes is set. + return absl::InternalError( + "inconsistency between list of map keys and map"); + } + return Entry(std::move(key), std::move(value).value()); + } + + absl::StatusOr> NextKey( + const MapValue::GetContext& context) override { + CEL_RETURN_IF_ERROR(OnNext(context.value_factory())); + CEL_ASSIGN_OR_RETURN( + auto key, + (*keys_iterator_) + ->NextValue(ListValue::GetContext(context.value_factory()))); + return key; + } + + private: + absl::Status OnNext(ValueFactory& value_factory) { + if (ABSL_PREDICT_FALSE(!keys_iterator_.has_value())) { + CEL_ASSIGN_OR_RETURN( + keys_, value_->ListKeys(MapValue::ListKeysContext(value_factory))); + CEL_ASSIGN_OR_RETURN(keys_iterator_, + keys_->NewIterator(value_factory.memory_manager())); + ABSL_CHECK((*keys_iterator_)->HasNext()); // Crash OK + } + return absl::OkStatus(); + } + + const AbstractMapValue* const value_; + Handle keys_; + absl::optional> keys_iterator_; +}; + +} // namespace + Handle LegacyMapValue::type() const { return HandleFactory::Make(); } @@ -100,6 +255,11 @@ absl::StatusOr> LegacyMapValue::ListKeys( return LegacyMapValueListKeys(impl_, context.value_factory()); } +absl::StatusOr> LegacyMapValue::NewIterator( + MemoryManager& memory_manager) const { + return MakeUnique(memory_manager, impl_); +} + AbstractMapValue::AbstractMapValue(Handle type) : HeapData(kKind), type_(std::move(type)) { // Ensure `Value*` and `HeapData*` are not thunked. @@ -107,6 +267,11 @@ AbstractMapValue::AbstractMapValue(Handle type) reinterpret_cast(static_cast(this))); } +absl::StatusOr> AbstractMapValue::NewIterator( + MemoryManager& memory_manager) const { + return MakeUnique(memory_manager, this); +} + } // namespace base_internal } // namespace cel diff --git a/base/values/map_value.h b/base/values/map_value.h index ba3ac382c..2f7477385 100644 --- a/base/values/map_value.h +++ b/base/values/map_value.h @@ -18,6 +18,7 @@ #include #include #include +#include #include "absl/base/attributes.h" #include "absl/log/absl_check.h" @@ -26,6 +27,7 @@ #include "base/handle.h" #include "base/internal/data.h" #include "base/kind.h" +#include "base/memory.h" #include "base/owner.h" #include "base/type.h" #include "base/types/map_type.h" @@ -107,6 +109,19 @@ class MapValue : public Value { absl::StatusOr> ListKeys( const ListKeysContext& context) const; + struct Entry final { + Entry(Handle key, Handle value) + : key(std::move(key)), value(std::move(value)) {} + + Handle key; + Handle value; + }; + + class Iterator; + + absl::StatusOr> NewIterator( + MemoryManager& memory_manager) const ABSL_ATTRIBUTE_LIFETIME_BOUND; + private: friend internal::TypeInfo base_internal::GetMapValueTypeId( const MapValue& map_value); @@ -120,6 +135,26 @@ class MapValue : public Value { internal::TypeInfo TypeId() const; }; +// Abstract class describes an iterator which can iterate over the entries in a +// map. A default implementation is provided by `MapValue::NewIterator`, however +// it is likely not as efficient as providing your own implementation. +class MapValue::Iterator { + public: + using Entry = MapValue::Entry; + + virtual ~Iterator() = default; + + ABSL_MUST_USE_RESULT virtual bool HasNext() = 0; + + virtual absl::StatusOr Next(const MapValue::GetContext& context) = 0; + + virtual absl::StatusOr> NextKey( + const MapValue::GetContext& context); + + virtual absl::StatusOr> NextValue( + const MapValue::GetContext& context); +}; + CEL_INTERNAL_VALUE_DECL(MapValue); namespace base_internal { @@ -166,6 +201,9 @@ class LegacyMapValue final : public MapValue, public InlineData { absl::StatusOr> ListKeys( const ListKeysContext& context) const; + absl::StatusOr> NewIterator( + MemoryManager& memory_manager) const ABSL_ATTRIBUTE_LIFETIME_BOUND; + constexpr uintptr_t value() const { return impl_; } private: @@ -220,6 +258,9 @@ class AbstractMapValue : public MapValue, virtual absl::StatusOr> ListKeys( const ListKeysContext& context) const = 0; + virtual absl::StatusOr> NewIterator( + MemoryManager& memory_manager) const ABSL_ATTRIBUTE_LIFETIME_BOUND; + protected: explicit AbstractMapValue(Handle type); diff --git a/eval/internal/interop_test.cc b/eval/internal/interop_test.cc index 25399c5b4..d5bf3ae55 100644 --- a/eval/internal/interop_test.cc +++ b/eval/internal/interop_test.cc @@ -15,6 +15,7 @@ #include "eval/internal/interop.h" #include +#include #include #include #include @@ -24,7 +25,7 @@ #include "absl/status/status.h" #include "absl/strings/escaping.h" #include "absl/time/time.h" -#include "base/memory_manager.h" +#include "base/memory.h" #include "base/type_manager.h" #include "base/value.h" #include "base/value_factory.h" @@ -465,6 +466,60 @@ TEST(ValueInterop, LegacyListRoundtrip) { EXPECT_EQ(value.ListOrDie(), legacy_value.ListOrDie()); } +TEST(ValueInterop, LegacyListNewIteratorIndices) { + google::protobuf::Arena arena; + extensions::ProtoMemoryManager memory_manager(&arena); + TypeFactory type_factory(memory_manager); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); + auto value = + CelValue::CreateList(google::protobuf::Arena::Create< + google::api::expr::runtime::ContainerBackedListImpl>( + &arena, std::vector{CelValue::CreateInt64(0), + CelValue::CreateInt64(1), + CelValue::CreateInt64(2)})); + ASSERT_OK_AND_ASSIGN(auto modern_value, FromLegacyValue(&arena, value)); + ASSERT_OK_AND_ASSIGN( + auto iterator, modern_value->As().NewIterator(memory_manager)); + std::set actual_indices; + while (iterator->HasNext()) { + ASSERT_OK_AND_ASSIGN( + auto index, iterator->NextIndex(ListValue::GetContext(value_factory))); + actual_indices.insert(index); + } + EXPECT_THAT(iterator->NextIndex(ListValue::GetContext(value_factory)), + StatusIs(absl::StatusCode::kFailedPrecondition)); + std::set expected_indices = {0, 1, 2}; + EXPECT_EQ(actual_indices, expected_indices); +} + +TEST(ValueInterop, LegacyListNewIteratorValues) { + google::protobuf::Arena arena; + extensions::ProtoMemoryManager memory_manager(&arena); + TypeFactory type_factory(memory_manager); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); + auto value = + CelValue::CreateList(google::protobuf::Arena::Create< + google::api::expr::runtime::ContainerBackedListImpl>( + &arena, std::vector{CelValue::CreateInt64(3), + CelValue::CreateInt64(4), + CelValue::CreateInt64(5)})); + ASSERT_OK_AND_ASSIGN(auto modern_value, FromLegacyValue(&arena, value)); + ASSERT_OK_AND_ASSIGN( + auto iterator, modern_value->As().NewIterator(memory_manager)); + std::set actual_values; + while (iterator->HasNext()) { + ASSERT_OK_AND_ASSIGN( + auto value, iterator->NextValue(ListValue::GetContext(value_factory))); + actual_values.insert(value->As().value()); + } + EXPECT_THAT(iterator->NextValue(ListValue::GetContext(value_factory)), + StatusIs(absl::StatusCode::kFailedPrecondition)); + std::set expected_values = {3, 4, 5}; + EXPECT_EQ(actual_values, expected_values); +} + TEST(ValueInterop, MapFromLegacy) { google::protobuf::Arena arena; extensions::ProtoMemoryManager memory_manager(&arena); @@ -609,6 +664,66 @@ TEST(ValueInterop, LegacyMapRoundtrip) { EXPECT_EQ(value.MapOrDie(), legacy_value.MapOrDie()); } +TEST(ValueInterop, LegacyMapNewIteratorKeys) { + google::protobuf::Arena arena; + extensions::ProtoMemoryManager memory_manager(&arena); + TypeFactory type_factory(memory_manager); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); + auto* map_builder = + google::protobuf::Arena::Create(&arena); + ASSERT_OK(map_builder->Add(CelValue::CreateStringView("foo"), + CelValue::CreateInt64(1))); + ASSERT_OK(map_builder->Add(CelValue::CreateStringView("bar"), + CelValue::CreateInt64(2))); + ASSERT_OK(map_builder->Add(CelValue::CreateStringView("baz"), + CelValue::CreateInt64(3))); + auto value = CelValue::CreateMap(map_builder); + ASSERT_OK_AND_ASSIGN(auto modern_value, FromLegacyValue(&arena, value)); + ASSERT_OK_AND_ASSIGN( + auto iterator, modern_value->As().NewIterator(memory_manager)); + std::set actual_keys; + while (iterator->HasNext()) { + ASSERT_OK_AND_ASSIGN( + auto key, iterator->NextKey(MapValue::GetContext(value_factory))); + actual_keys.insert(key->As().ToString()); + } + EXPECT_THAT(iterator->NextKey(MapValue::GetContext(value_factory)), + StatusIs(absl::StatusCode::kFailedPrecondition)); + std::set expected_keys = {"foo", "bar", "baz"}; + EXPECT_EQ(actual_keys, expected_keys); +} + +TEST(ValueInterop, LegacyMapNewIteratorValues) { + google::protobuf::Arena arena; + extensions::ProtoMemoryManager memory_manager(&arena); + TypeFactory type_factory(memory_manager); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); + auto* map_builder = + google::protobuf::Arena::Create(&arena); + ASSERT_OK(map_builder->Add(CelValue::CreateStringView("foo"), + CelValue::CreateInt64(1))); + ASSERT_OK(map_builder->Add(CelValue::CreateStringView("bar"), + CelValue::CreateInt64(2))); + ASSERT_OK(map_builder->Add(CelValue::CreateStringView("baz"), + CelValue::CreateInt64(3))); + auto value = CelValue::CreateMap(map_builder); + ASSERT_OK_AND_ASSIGN(auto modern_value, FromLegacyValue(&arena, value)); + ASSERT_OK_AND_ASSIGN( + auto iterator, modern_value->As().NewIterator(memory_manager)); + std::set actual_values; + while (iterator->HasNext()) { + ASSERT_OK_AND_ASSIGN( + auto value, iterator->NextValue(MapValue::GetContext(value_factory))); + actual_values.insert(value->As().value()); + } + EXPECT_THAT(iterator->NextValue(MapValue::GetContext(value_factory)), + StatusIs(absl::StatusCode::kFailedPrecondition)); + std::set expected_values = {1, 2, 3}; + EXPECT_EQ(actual_values, expected_values); +} + TEST(ValueInterop, StructFromLegacy) { google::protobuf::Arena arena; extensions::ProtoMemoryManager memory_manager(&arena); From cdd9574e98837497e3d11bfe96b498f1ba2ce028 Mon Sep 17 00:00:00 2001 From: jcking Date: Tue, 9 May 2023 17:03:00 +0000 Subject: [PATCH 251/303] Remove deprecated aliases and cleanup artifacts from `:memory` consolidation PiperOrigin-RevId: 530640656 --- base/BUILD | 73 +++---------------- base/allocator.h | 22 ------ base/allocator_test.cc | 40 ---------- base/function_adapter_test.cc | 2 +- base/internal/BUILD | 2 +- base/internal/function_adapter_test.cc | 2 +- base/{memory_manager.cc => memory.cc} | 4 +- base/memory_manager.h | 22 ------ ...{memory_manager_test.cc => memory_test.cc} | 17 ++++- base/type_factory.h | 2 +- base/type_factory_test.cc | 2 +- base/type_provider_test.cc | 2 +- base/type_test.cc | 2 +- base/value_factory.h | 2 +- base/value_factory_test.cc | 2 +- base/values/list_value_builder.h | 2 +- base/values/list_value_builder_test.cc | 2 +- base/values/map_value_builder.h | 2 +- base/values/map_value_builder_test.cc | 2 +- eval/eval/BUILD | 15 ++-- eval/eval/attribute_trail.h | 2 +- eval/eval/attribute_utility.h | 2 +- eval/eval/container_access_step.cc | 2 +- eval/eval/evaluator_core.h | 2 +- eval/eval/select_step.cc | 2 +- eval/internal/BUILD | 6 +- eval/internal/adapter_activation_impl.cc | 2 +- eval/internal/errors.cc | 2 +- eval/internal/errors.h | 2 +- eval/public/BUILD | 4 +- eval/public/cel_value.cc | 2 +- eval/public/cel_value.h | 2 +- eval/public/cel_value_test.cc | 2 +- eval/public/structs/BUILD | 6 +- eval/public/structs/legacy_type_adapter.h | 2 +- .../structs/proto_message_type_adapter.h | 2 +- extensions/protobuf/BUILD | 6 +- extensions/protobuf/internal/BUILD | 2 +- extensions/protobuf/internal/testing.h | 2 +- extensions/protobuf/memory_manager.h | 2 +- extensions/protobuf/struct_type_test.cc | 2 +- extensions/protobuf/struct_value.cc | 2 +- extensions/protobuf/type_provider_test.cc | 2 +- runtime/BUILD | 2 +- runtime/activation_test.cc | 2 +- 45 files changed, 80 insertions(+), 203 deletions(-) delete mode 100644 base/allocator.h delete mode 100644 base/allocator_test.cc rename base/{memory_manager.cc => memory.cc} (99%) delete mode 100644 base/memory_manager.h rename base/{memory_manager_test.cc => memory_test.cc} (76%) diff --git a/base/BUILD b/base/BUILD index 355d26817..bdeeeea53 100644 --- a/base/BUILD +++ b/base/BUILD @@ -19,21 +19,6 @@ package( licenses(["notice"]) -alias( - name = "allocator", - actual = ":memory", -) - -cc_test( - name = "allocator_test", - srcs = ["allocator_test.cc"], - deps = [ - ":allocator", - ":memory_manager", - "//internal:testing", - ], -) - cc_library( name = "attributes", srcs = [ @@ -95,11 +80,9 @@ cc_test( cc_library( name = "memory", - srcs = ["memory_manager.cc"], + srcs = ["memory.cc"], hdrs = [ - "allocator.h", "memory.h", - "memory_manager.h", ], deps = [ ":handle", @@ -116,16 +99,13 @@ cc_library( ], ) -alias( - name = "memory_manager", - actual = ":memory", -) - cc_test( - name = "memory_manager_test", - srcs = ["memory_manager_test.cc"], + name = "memory_test", + srcs = [ + "memory_test.cc", + ], deps = [ - ":memory_manager", + ":memory", "//internal:testing", ], ) @@ -175,7 +155,7 @@ cc_library( deps = [ ":handle", ":kind", - ":memory_manager", + ":memory", "//base/internal:data", "//base/internal:type", "//internal:casts", @@ -197,30 +177,6 @@ cc_library( ], ) -alias( - name = "type_manager", - actual = "type", - deprecation = "Use :type instead", -) - -alias( - name = "type_provider", - actual = "type", - deprecation = "Use :type instead", -) - -alias( - name = "type_registry", - actual = "type", - deprecation = "Use :type instead", -) - -alias( - name = "type_factory", - actual = "type", - deprecation = "Use :type instead", -) - cc_test( name = "type_test", srcs = [ @@ -230,7 +186,7 @@ cc_test( ], deps = [ ":handle", - ":memory_manager", + ":memory", ":type", ":value", "//base/internal:memory_manager_testing", @@ -254,12 +210,11 @@ cc_library( "value_factory.h", ] + glob(["values/*.h"]), deps = [ - ":allocator", ":attributes", ":function_result_set", ":handle", ":kind", - ":memory_manager", + ":memory", ":owner", ":type", "//base/internal:data", @@ -289,12 +244,6 @@ cc_library( ], ) -alias( - name = "value_factory", - actual = "value", - deprecation = "Use :value instead", -) - cc_test( name = "value_test", srcs = [ @@ -302,7 +251,7 @@ cc_test( "value_test.cc", ] + glob(["values/*_test.cc"]), deps = [ - ":memory_manager", + ":memory", ":type", ":value", "//base/internal:memory_manager_testing", @@ -425,7 +374,7 @@ cc_test( ":function_descriptor", ":handle", ":kind", - ":memory_manager", + ":memory", ":type", ":value", "//internal:testing", diff --git a/base/allocator.h b/base/allocator.h deleted file mode 100644 index e75217f9c..000000000 --- a/base/allocator.h +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright 2023 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef THIRD_PARTY_CEL_CPP_BASE_ALLOCATOR_H_ -#define THIRD_PARTY_CEL_CPP_BASE_ALLOCATOR_H_ - -// TODO(issues/5): delete - -#include "base/memory.h" - -#endif // THIRD_PARTY_CEL_CPP_BASE_ALLOCATOR_H_ diff --git a/base/allocator_test.cc b/base/allocator_test.cc deleted file mode 100644 index 098417eba..000000000 --- a/base/allocator_test.cc +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright 2023 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "base/allocator.h" - -#include - -#include "base/memory_manager.h" -#include "internal/testing.h" - -namespace cel { -namespace { - -TEST(Allocator, Global) { - std::vector> vector( - Allocator{MemoryManager::Global()}); - vector.push_back(0); - vector.resize(64, 0); -} - -TEST(Allocator, Arena) { - auto memory_manager = ArenaMemoryManager::Default(); - std::vector> vector(Allocator{*memory_manager}); - vector.push_back(0); - vector.resize(64, 0); -} - -} // namespace -} // namespace cel diff --git a/base/function_adapter_test.cc b/base/function_adapter_test.cc index f5f75f7bf..124e18999 100644 --- a/base/function_adapter_test.cc +++ b/base/function_adapter_test.cc @@ -25,7 +25,7 @@ #include "base/function_descriptor.h" #include "base/handle.h" #include "base/kind.h" -#include "base/memory_manager.h" +#include "base/memory.h" #include "base/type_factory.h" #include "base/type_provider.h" #include "base/value_factory.h" diff --git a/base/internal/BUILD b/base/internal/BUILD index 6c364a30b..f8bcbccef 100644 --- a/base/internal/BUILD +++ b/base/internal/BUILD @@ -172,7 +172,7 @@ cc_test( ":function_adapter", "//base:handle", "//base:kind", - "//base:memory_manager", + "//base:memory", "//base:type", "//base:value", "//internal:testing", diff --git a/base/internal/function_adapter_test.cc b/base/internal/function_adapter_test.cc index 758a85d6c..ea18006b8 100644 --- a/base/internal/function_adapter_test.cc +++ b/base/internal/function_adapter_test.cc @@ -22,7 +22,7 @@ #include "absl/time/time.h" #include "base/handle.h" #include "base/kind.h" -#include "base/memory_manager.h" +#include "base/memory.h" #include "base/type_factory.h" #include "base/type_manager.h" #include "base/type_provider.h" diff --git a/base/memory_manager.cc b/base/memory.cc similarity index 99% rename from base/memory_manager.cc rename to base/memory.cc index 13b4ff975..82992965d 100644 --- a/base/memory_manager.cc +++ b/base/memory.cc @@ -1,4 +1,4 @@ -// Copyright 2022 Google LLC +// Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "base/memory_manager.h" +#include "base/memory.h" #ifndef _WIN32 #include diff --git a/base/memory_manager.h b/base/memory_manager.h deleted file mode 100644 index f73e9f84b..000000000 --- a/base/memory_manager.h +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright 2022 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef THIRD_PARTY_CEL_CPP_BASE_MEMORY_MANAGER_H_ -#define THIRD_PARTY_CEL_CPP_BASE_MEMORY_MANAGER_H_ - -// TODO(issues/5): delete - -#include "base/memory.h" - -#endif // THIRD_PARTY_CEL_CPP_BASE_MEMORY_MANAGER_H_ diff --git a/base/memory_manager_test.cc b/base/memory_test.cc similarity index 76% rename from base/memory_manager_test.cc rename to base/memory_test.cc index c31cca027..5351a2a36 100644 --- a/base/memory_manager_test.cc +++ b/base/memory_test.cc @@ -12,9 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "base/memory_manager.h" +#include "base/memory.h" #include +#include #include "internal/testing.h" @@ -42,5 +43,19 @@ TEST(ArenaMemoryManager, NotTriviallyDestuctible) { } } +TEST(Allocator, Global) { + std::vector> vector( + Allocator{MemoryManager::Global()}); + vector.push_back(0); + vector.resize(64, 0); +} + +TEST(Allocator, Arena) { + auto memory_manager = ArenaMemoryManager::Default(); + std::vector> vector(Allocator{*memory_manager}); + vector.push_back(0); + vector.resize(64, 0); +} + } // namespace } // namespace cel diff --git a/base/type_factory.h b/base/type_factory.h index 5406ada87..cdbdc7e3c 100644 --- a/base/type_factory.h +++ b/base/type_factory.h @@ -24,7 +24,7 @@ #include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" #include "base/handle.h" -#include "base/memory_manager.h" +#include "base/memory.h" #include "base/types/any_type.h" #include "base/types/bool_type.h" #include "base/types/bytes_type.h" diff --git a/base/type_factory_test.cc b/base/type_factory_test.cc index 4f8e1a3d3..ff9901e87 100644 --- a/base/type_factory_test.cc +++ b/base/type_factory_test.cc @@ -15,7 +15,7 @@ #include "base/type_factory.h" #include "absl/status/status.h" -#include "base/memory_manager.h" +#include "base/memory.h" #include "internal/testing.h" namespace cel { diff --git a/base/type_provider_test.cc b/base/type_provider_test.cc index ec65eb308..3e99dc390 100644 --- a/base/type_provider_test.cc +++ b/base/type_provider_test.cc @@ -17,7 +17,7 @@ #include #include "base/internal/memory_manager_testing.h" -#include "base/memory_manager.h" +#include "base/memory.h" #include "base/type_factory.h" #include "internal/testing.h" diff --git a/base/type_test.cc b/base/type_test.cc index ac3b7a10a..d583e7286 100644 --- a/base/type_test.cc +++ b/base/type_test.cc @@ -22,7 +22,7 @@ #include "absl/status/status.h" #include "base/handle.h" #include "base/internal/memory_manager_testing.h" -#include "base/memory_manager.h" +#include "base/memory.h" #include "base/type_factory.h" #include "base/type_manager.h" #include "base/value.h" diff --git a/base/value_factory.h b/base/value_factory.h index 63fdee57e..5bcc170a2 100644 --- a/base/value_factory.h +++ b/base/value_factory.h @@ -31,7 +31,7 @@ #include "base/attribute_set.h" #include "base/function_result_set.h" #include "base/handle.h" -#include "base/memory_manager.h" +#include "base/memory.h" #include "base/owner.h" #include "base/type_manager.h" #include "base/value.h" diff --git a/base/value_factory_test.cc b/base/value_factory_test.cc index 36d7ac285..ce0b040c7 100644 --- a/base/value_factory_test.cc +++ b/base/value_factory_test.cc @@ -15,7 +15,7 @@ #include "base/value_factory.h" #include "absl/status/status.h" -#include "base/memory_manager.h" +#include "base/memory.h" #include "internal/testing.h" namespace cel { diff --git a/base/values/list_value_builder.h b/base/values/list_value_builder.h index c773d3dab..b655d41a7 100644 --- a/base/values/list_value_builder.h +++ b/base/values/list_value_builder.h @@ -23,7 +23,7 @@ #include "absl/base/attributes.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "base/allocator.h" +#include "base/memory.h" #include "base/value_factory.h" #include "base/values/list_value.h" diff --git a/base/values/list_value_builder_test.cc b/base/values/list_value_builder_test.cc index ea5208b36..06afbf0e9 100644 --- a/base/values/list_value_builder_test.cc +++ b/base/values/list_value_builder_test.cc @@ -15,7 +15,7 @@ #include "base/values/list_value_builder.h" #include "absl/time/time.h" -#include "base/memory_manager.h" +#include "base/memory.h" #include "base/type_factory.h" #include "base/type_provider.h" #include "internal/testing.h" diff --git a/base/values/map_value_builder.h b/base/values/map_value_builder.h index fe30d9349..5aa788254 100644 --- a/base/values/map_value_builder.h +++ b/base/values/map_value_builder.h @@ -25,7 +25,7 @@ #include "absl/base/attributes.h" #include "absl/base/macros.h" #include "absl/status/statusor.h" -#include "base/allocator.h" +#include "base/memory.h" #include "base/value_factory.h" #include "base/values/list_value_builder.h" #include "base/values/map_value.h" diff --git a/base/values/map_value_builder_test.cc b/base/values/map_value_builder_test.cc index f710ef098..a325a1655 100644 --- a/base/values/map_value_builder_test.cc +++ b/base/values/map_value_builder_test.cc @@ -15,7 +15,7 @@ #include "base/values/map_value_builder.h" #include "absl/time/time.h" -#include "base/memory_manager.h" +#include "base/memory.h" #include "base/type_factory.h" #include "base/type_provider.h" #include "internal/testing.h" diff --git a/eval/eval/BUILD b/eval/eval/BUILD index 33a3a8b19..790d5b183 100644 --- a/eval/eval/BUILD +++ b/eval/eval/BUILD @@ -20,7 +20,7 @@ cc_library( ":evaluator_stack", "//base:ast_internal", "//base:handle", - "//base:memory_manager", + "//base:memory", "//base:type", "//base:value", "//eval/internal:adapter_activation_impl", @@ -71,11 +71,8 @@ cc_test( ], deps = [ ":evaluator_stack", - "//base:type_factory", - "//base:type_manager", - "//base:type_provider", + "//base:type", "//base:value", - "//base:value_factory", "//extensions/protobuf:memory_manager", "//internal:testing", ], @@ -123,7 +120,7 @@ cc_library( ":expression_step_base", "//base:attributes", "//base:kind", - "//base:memory_manager", + "//base:memory", "//base:value", "//eval/internal:errors", "//eval/internal:interop", @@ -230,7 +227,7 @@ cc_library( ":expression_step_base", "//base:ast_internal", "//base:handle", - "//base:memory_manager", + "//base:memory", "//base:type", "//base:value", "//eval/internal:errors", @@ -668,7 +665,7 @@ cc_library( srcs = ["attribute_trail.cc"], hdrs = ["attribute_trail.h"], deps = [ - "//base:memory_manager", + "//base:memory", "//eval/public:cel_attribute", "//eval/public:cel_expression", "//eval/public:cel_value", @@ -709,7 +706,7 @@ cc_library( "//base:function_result", "//base:function_result_set", "//base:handle", - "//base:memory_manager", + "//base:memory", "//base:value", "//eval/public:unknown_set", "//extensions/protobuf:memory_manager", diff --git a/eval/eval/attribute_trail.h b/eval/eval/attribute_trail.h index 4afbc10eb..8e485aa03 100644 --- a/eval/eval/attribute_trail.h +++ b/eval/eval/attribute_trail.h @@ -9,7 +9,7 @@ #include "google/protobuf/arena.h" #include "absl/types/optional.h" #include "absl/utility/utility.h" -#include "base/memory_manager.h" +#include "base/memory.h" #include "eval/public/cel_attribute.h" #include "eval/public/cel_expression.h" #include "eval/public/cel_value.h" diff --git a/eval/eval/attribute_utility.h b/eval/eval/attribute_utility.h index 2dc8260fa..d09946c89 100644 --- a/eval/eval/attribute_utility.h +++ b/eval/eval/attribute_utility.h @@ -11,7 +11,7 @@ #include "base/function_result.h" #include "base/function_result_set.h" #include "base/handle.h" -#include "base/memory_manager.h" +#include "base/memory.h" #include "base/value.h" #include "eval/eval/attribute_trail.h" #include "eval/public/unknown_set.h" diff --git a/eval/eval/container_access_step.cc b/eval/eval/container_access_step.cc index 95eeefd6b..e1b3edd66 100644 --- a/eval/eval/container_access_step.cc +++ b/eval/eval/container_access_step.cc @@ -11,7 +11,7 @@ #include "absl/types/span.h" #include "base/attribute.h" #include "base/kind.h" -#include "base/memory_manager.h" +#include "base/memory.h" #include "base/value.h" #include "base/values/bool_value.h" #include "base/values/double_value.h" diff --git a/eval/eval/evaluator_core.h b/eval/eval/evaluator_core.h index 43e812f2d..3045c0419 100644 --- a/eval/eval/evaluator_core.h +++ b/eval/eval/evaluator_core.h @@ -23,7 +23,7 @@ #include "absl/types/optional.h" #include "base/ast_internal.h" #include "base/handle.h" -#include "base/memory_manager.h" +#include "base/memory.h" #include "base/type_manager.h" #include "base/value.h" #include "base/value_factory.h" diff --git a/eval/eval/select_step.cc b/eval/eval/select_step.cc index 82fc569c3..3d835f0eb 100644 --- a/eval/eval/select_step.cc +++ b/eval/eval/select_step.cc @@ -9,7 +9,7 @@ #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "base/handle.h" -#include "base/memory_manager.h" +#include "base/memory.h" #include "base/type_manager.h" #include "base/value_factory.h" #include "base/values/error_value.h" diff --git a/eval/internal/BUILD b/eval/internal/BUILD index 862936bb8..0abc3ca8e 100644 --- a/eval/internal/BUILD +++ b/eval/internal/BUILD @@ -52,7 +52,7 @@ cc_test( deps = [ ":errors", ":interop", - "//base:memory_manager", + "//base:memory", "//base:type", "//base:value", "//eval/public:cel_value", @@ -76,7 +76,7 @@ cc_library( srcs = ["errors.cc"], hdrs = ["errors.h"], deps = [ - "//base:memory_manager", + "//base:memory", "//extensions/protobuf:memory_manager", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", @@ -92,7 +92,7 @@ cc_library( ":interop", "//base:attributes", "//base:handle", - "//base:memory_manager", + "//base:memory", "//base:value", "//eval/public:base_activation", "//eval/public:cel_value", diff --git a/eval/internal/adapter_activation_impl.cc b/eval/internal/adapter_activation_impl.cc index 6faef26f6..e8055304f 100644 --- a/eval/internal/adapter_activation_impl.cc +++ b/eval/internal/adapter_activation_impl.cc @@ -18,7 +18,7 @@ #include "absl/strings/string_view.h" #include "absl/types/optional.h" -#include "base/memory_manager.h" +#include "base/memory.h" #include "eval/internal/interop.h" #include "eval/public/cel_value.h" #include "extensions/protobuf/memory_manager.h" diff --git a/eval/internal/errors.cc b/eval/internal/errors.cc index 89784b959..60f170457 100644 --- a/eval/internal/errors.cc +++ b/eval/internal/errors.cc @@ -17,7 +17,7 @@ #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" -#include "base/memory_manager.h" +#include "base/memory.h" #include "extensions/protobuf/memory_manager.h" namespace cel::interop_internal { diff --git a/eval/internal/errors.h b/eval/internal/errors.h index ef3b41ecb..aebd71522 100644 --- a/eval/internal/errors.h +++ b/eval/internal/errors.h @@ -17,7 +17,7 @@ #include "google/protobuf/arena.h" #include "absl/status/status.h" -#include "base/memory_manager.h" +#include "base/memory.h" namespace cel::interop_internal { diff --git a/eval/public/BUILD b/eval/public/BUILD index c761311aa..520cd7ccc 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -74,7 +74,7 @@ cc_library( ":message_wrapper", ":unknown_set", "//base:kind", - "//base:memory_manager", + "//base:memory", "//eval/internal:errors", "//eval/public/structs:legacy_type_info_apis", "//extensions/protobuf:memory_manager", @@ -701,7 +701,7 @@ cc_test( ":cel_value_internal", ":unknown_attribute_set", ":unknown_set", - "//base:memory_manager", + "//base:memory", "//eval/internal:errors", "//eval/public/structs:legacy_type_info_apis", "//eval/public/structs:trivial_legacy_type_info", diff --git a/eval/public/cel_value.cc b/eval/public/cel_value.cc index e03c2ee68..2086c47b8 100644 --- a/eval/public/cel_value.cc +++ b/eval/public/cel_value.cc @@ -11,7 +11,7 @@ #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" -#include "base/memory_manager.h" +#include "base/memory.h" #include "eval/internal/errors.h" #include "eval/public/cel_value_internal.h" #include "eval/public/structs/legacy_type_info_apis.h" diff --git a/eval/public/cel_value.h b/eval/public/cel_value.h index 57fff2ca9..d824c5f11 100644 --- a/eval/public/cel_value.h +++ b/eval/public/cel_value.h @@ -34,7 +34,7 @@ #include "absl/types/optional.h" #include "absl/types/variant.h" #include "base/kind.h" -#include "base/memory_manager.h" +#include "base/memory.h" #include "eval/public/cel_value_internal.h" #include "eval/public/message_wrapper.h" #include "eval/public/unknown_set.h" diff --git a/eval/public/cel_value_test.cc b/eval/public/cel_value_test.cc index 88ea1aa35..18ebb547d 100644 --- a/eval/public/cel_value_test.cc +++ b/eval/public/cel_value_test.cc @@ -6,7 +6,7 @@ #include "absl/strings/match.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" -#include "base/memory_manager.h" +#include "base/memory.h" #include "eval/internal/errors.h" #include "eval/public/cel_value_internal.h" #include "eval/public/structs/legacy_type_info_apis.h" diff --git a/eval/public/structs/BUILD b/eval/public/structs/BUILD index 91d388c70..b8edab5eb 100644 --- a/eval/public/structs/BUILD +++ b/eval/public/structs/BUILD @@ -194,7 +194,7 @@ cc_library( deps = [ ":legacy_any_packing", ":legacy_type_adapter", - "//base:type_provider", + "//base:type", "@com_google_absl//absl/types:optional", ], ) @@ -203,7 +203,7 @@ cc_library( name = "legacy_type_adapter", hdrs = ["legacy_type_adapter.h"], deps = [ - "//base:memory_manager", + "//base:memory", "//eval/public:cel_options", "//eval/public:cel_value", "@com_google_absl//absl/status", @@ -235,7 +235,7 @@ cc_library( ":field_access_impl", ":legacy_type_adapter", ":legacy_type_info_apis", - "//base:memory_manager", + "//base:memory", "//eval/public:cel_options", "//eval/public:cel_value", "//eval/public:message_wrapper", diff --git a/eval/public/structs/legacy_type_adapter.h b/eval/public/structs/legacy_type_adapter.h index 1ddc9536e..6eee0941e 100644 --- a/eval/public/structs/legacy_type_adapter.h +++ b/eval/public/structs/legacy_type_adapter.h @@ -19,7 +19,7 @@ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_LEGACY_TYPE_ADPATER_H_ #include "absl/status/status.h" -#include "base/memory_manager.h" +#include "base/memory.h" #include "eval/public/cel_options.h" #include "eval/public/cel_value.h" diff --git a/eval/public/structs/proto_message_type_adapter.h b/eval/public/structs/proto_message_type_adapter.h index d56540e3e..3447da20b 100644 --- a/eval/public/structs/proto_message_type_adapter.h +++ b/eval/public/structs/proto_message_type_adapter.h @@ -19,7 +19,7 @@ #include "google/protobuf/message.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" -#include "base/memory_manager.h" +#include "base/memory.h" #include "eval/public/cel_options.h" #include "eval/public/cel_value.h" #include "eval/public/structs/legacy_type_adapter.h" diff --git a/extensions/protobuf/BUILD b/extensions/protobuf/BUILD index 9cfcc7d87..3d75e5532 100644 --- a/extensions/protobuf/BUILD +++ b/extensions/protobuf/BUILD @@ -24,7 +24,7 @@ cc_library( srcs = ["memory_manager.cc"], hdrs = ["memory_manager.h"], deps = [ - "//base:memory_manager", + "//base:memory", "//internal:casts", "//internal:rtti", "@com_google_absl//absl/base:core_headers", @@ -119,7 +119,7 @@ cc_test( deps = [ ":type", "//base:kind", - "//base:memory_manager", + "//base:memory", "//base:type", "//base/internal:memory_manager_testing", "//base/testing:type_matchers", @@ -146,9 +146,9 @@ cc_library( deps = [ ":memory_manager", ":type", - "//base:allocator", "//base:handle", "//base:kind", + "//base:memory", "//base:owner", "//base:type", "//base:value", diff --git a/extensions/protobuf/internal/BUILD b/extensions/protobuf/internal/BUILD index bd10c452e..475195c52 100644 --- a/extensions/protobuf/internal/BUILD +++ b/extensions/protobuf/internal/BUILD @@ -31,7 +31,7 @@ cc_library( testonly = True, hdrs = ["testing.h"], deps = [ - "//base:memory_manager", + "//base:memory", "//base/internal:memory_manager_testing", "//extensions/protobuf:memory_manager", "//internal:testing", diff --git a/extensions/protobuf/internal/testing.h b/extensions/protobuf/internal/testing.h index 443891eb3..1e767a19a 100644 --- a/extensions/protobuf/internal/testing.h +++ b/extensions/protobuf/internal/testing.h @@ -19,7 +19,7 @@ #include "absl/types/optional.h" #include "base/internal/memory_manager_testing.h" -#include "base/memory_manager.h" +#include "base/memory.h" #include "extensions/protobuf/memory_manager.h" #include "internal/testing.h" #include "google/protobuf/arena.h" diff --git a/extensions/protobuf/memory_manager.h b/extensions/protobuf/memory_manager.h index bc15d6413..1a9b03e42 100644 --- a/extensions/protobuf/memory_manager.h +++ b/extensions/protobuf/memory_manager.h @@ -20,7 +20,7 @@ #include "google/protobuf/arena.h" #include "absl/base/attributes.h" #include "absl/base/macros.h" -#include "base/memory_manager.h" +#include "base/memory.h" #include "internal/casts.h" #include "internal/rtti.h" diff --git a/extensions/protobuf/struct_type_test.cc b/extensions/protobuf/struct_type_test.cc index e0173f526..9dbbdd2a5 100644 --- a/extensions/protobuf/struct_type_test.cc +++ b/extensions/protobuf/struct_type_test.cc @@ -15,7 +15,7 @@ #include "extensions/protobuf/struct_type.h" #include "google/protobuf/type.pb.h" -#include "base/memory_manager.h" +#include "base/memory.h" #include "base/type_factory.h" #include "base/type_manager.h" #include "base/types/list_type.h" diff --git a/extensions/protobuf/struct_value.cc b/extensions/protobuf/struct_value.cc index 04466a60b..75b524247 100644 --- a/extensions/protobuf/struct_value.cc +++ b/extensions/protobuf/struct_value.cc @@ -36,8 +36,8 @@ #include "absl/time/time.h" #include "absl/types/optional.h" #include "absl/types/variant.h" -#include "base/allocator.h" #include "base/handle.h" +#include "base/memory.h" #include "base/types/struct_type.h" #include "base/value.h" #include "base/value_factory.h" diff --git a/extensions/protobuf/type_provider_test.cc b/extensions/protobuf/type_provider_test.cc index 22fb4ce1b..9b862f89a 100644 --- a/extensions/protobuf/type_provider_test.cc +++ b/extensions/protobuf/type_provider_test.cc @@ -16,7 +16,7 @@ #include "google/protobuf/type.pb.h" #include "base/kind.h" -#include "base/memory_manager.h" +#include "base/memory.h" #include "base/type_factory.h" #include "base/type_manager.h" #include "extensions/protobuf/enum_type.h" diff --git a/runtime/BUILD b/runtime/BUILD index 4762f82f0..cfddeefe2 100644 --- a/runtime/BUILD +++ b/runtime/BUILD @@ -84,7 +84,7 @@ cc_test( "//base:function", "//base:function_descriptor", "//base:handle", - "//base:memory_manager", + "//base:memory", "//base:type", "//base:value", "//internal:status_macros", diff --git a/runtime/activation_test.cc b/runtime/activation_test.cc index 6081bc3de..95a636800 100644 --- a/runtime/activation_test.cc +++ b/runtime/activation_test.cc @@ -23,7 +23,7 @@ #include "base/function.h" #include "base/function_descriptor.h" #include "base/handle.h" -#include "base/memory_manager.h" +#include "base/memory.h" #include "base/type_factory.h" #include "base/type_manager.h" #include "base/type_provider.h" From c8852960dc6647ab1a01d1fad97b99b398f1af25 Mon Sep 17 00:00:00 2001 From: jcking Date: Tue, 9 May 2023 20:06:06 +0000 Subject: [PATCH 252/303] Explicitly delete copy/move for `MemoryManager` PiperOrigin-RevId: 530691887 --- base/memory.h | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/base/memory.h b/base/memory.h index 4e49832b4..9fde6cb4a 100644 --- a/base/memory.h +++ b/base/memory.h @@ -153,8 +153,14 @@ class MemoryManager { public: ABSL_ATTRIBUTE_PURE_FUNCTION static MemoryManager& Global(); + MemoryManager(const MemoryManager&) = delete; + MemoryManager(MemoryManager&&) = delete; + virtual ~MemoryManager() = default; + MemoryManager& operator=(const MemoryManager&) = delete; + MemoryManager& operator=(MemoryManager&&) = delete; + private: friend class GlobalMemoryManager; friend class ArenaMemoryManager; From fc4e4c1b8d8a22928c33cfbeb9dfe22af059ad41 Mon Sep 17 00:00:00 2001 From: jcking Date: Tue, 9 May 2023 23:24:19 +0000 Subject: [PATCH 253/303] Add support for iterating over struct fields PiperOrigin-RevId: 530742960 --- base/BUILD | 1 + base/type_test.cc | 8 + base/types/struct_type.cc | 83 +++++++++++ base/types/struct_type.h | 123 +++++++++++---- base/value_test.cc | 16 ++ base/values/struct_value.cc | 80 ++++++++++ base/values/struct_value.h | 50 +++++++ eval/eval/select_step_test.cc | 2 + eval/internal/interop.cc | 21 +++ eval/internal/interop_test.cc | 56 +++++++ .../portable_cel_expr_builder_factory_test.cc | 11 ++ eval/public/structs/legacy_type_adapter.h | 5 + .../structs/legacy_type_adapter_test.cc | 5 + .../structs/proto_message_type_adapter.cc | 28 ++++ .../structs/proto_message_type_adapter.h | 5 + extensions/protobuf/BUILD | 1 + extensions/protobuf/struct_type.cc | 57 +++++++ extensions/protobuf/struct_type.h | 8 +- extensions/protobuf/struct_type_test.cc | 140 +++++++++++++++--- extensions/protobuf/struct_value.cc | 64 ++++++++ extensions/protobuf/struct_value.h | 10 ++ extensions/protobuf/struct_value_test.cc | 93 ++++++++++++ internal/BUILD | 5 + internal/overloaded.h | 30 ++++ 24 files changed, 851 insertions(+), 51 deletions(-) create mode 100644 internal/overloaded.h diff --git a/base/BUILD b/base/BUILD index bdeeeea53..a050068c4 100644 --- a/base/BUILD +++ b/base/BUILD @@ -160,6 +160,7 @@ cc_library( "//base/internal:type", "//internal:casts", "//internal:no_destructor", + "//internal:overloaded", "//internal:rtti", "//internal:status_macros", "@com_google_absl//absl/base", diff --git a/base/type_test.cc b/base/type_test.cc index d583e7286..56e244831 100644 --- a/base/type_test.cc +++ b/base/type_test.cc @@ -96,6 +96,14 @@ class TestStructType final : public CEL_STRUCT_TYPE_CLASS { public: absl::string_view name() const override { return "test_struct.TestStruct"; } + size_t field_count() const override { return 4; } + + absl::StatusOr> NewFieldIterator( + MemoryManager& memory_manager) const override { + return absl::UnimplementedError( + "StructType::NewFieldIterator() is unimplemented"); + } + protected: absl::StatusOr> FindFieldByName( TypeManager& type_manager, absl::string_view name) const override { diff --git a/base/types/struct_type.cc b/base/types/struct_type.cc index 874139518..9afcef8d2 100644 --- a/base/types/struct_type.cc +++ b/base/types/struct_type.cc @@ -15,17 +15,57 @@ #include "base/types/struct_type.h" #include +#include #include "absl/base/macros.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/variant.h" +#include "internal/overloaded.h" +#include "internal/status_macros.h" namespace cel { CEL_INTERNAL_TYPE_IMPL(StructType); +bool operator<(const StructType::FieldId& lhs, const StructType::FieldId& rhs) { + return absl::visit( + internal::Overloaded{ + [&rhs](absl::string_view lhs_name) { + return absl::visit( + internal::Overloaded{// (absl::string_view, absl::string_view) + [lhs_name](absl::string_view rhs_name) { + return lhs_name < rhs_name; + }, + // (absl::string_view, int64_t) + [](int64_t rhs_number) { return false; }}, + rhs.data_); + }, + [&rhs](int64_t lhs_number) { + return absl::visit( + internal::Overloaded{ + // (int64_t, absl::string_view) + [](absl::string_view rhs_name) { return true; }, + // (int64_t, int64_t) + [lhs_number](int64_t rhs_number) { + return lhs_number < rhs_number; + }, + }, + rhs.data_); + }}, + lhs.data_); +} + +std::string StructType::FieldId::DebugString() const { + return absl::visit( + internal::Overloaded{ + [](absl::string_view name) { return std::string(name); }, + [](int64_t number) { return absl::StrCat(number); }}, + data_); +} + #define CEL_INTERNAL_STRUCT_TYPE_DISPATCH(method, ...) \ base_internal::Metadata::IsStoredInline(*this) \ ? static_cast(*this).method( \ @@ -37,6 +77,10 @@ absl::string_view StructType::name() const { return CEL_INTERNAL_STRUCT_TYPE_DISPATCH(name); } +size_t StructType::field_count() const { + return CEL_INTERNAL_STRUCT_TYPE_DISPATCH(field_count); +} + std::string StructType::DebugString() const { return CEL_INTERNAL_STRUCT_TYPE_DISPATCH(DebugString); } @@ -56,6 +100,11 @@ absl::StatusOr> StructType::FindFieldByNumber( number); } +absl::StatusOr> +StructType::NewFieldIterator(MemoryManager& memory_manager) const { + return CEL_INTERNAL_STRUCT_TYPE_DISPATCH(NewFieldIterator, memory_manager); +} + #undef CEL_INTERNAL_STRUCT_TYPE_DISPATCH struct StructType::FindFieldVisitor final { @@ -77,12 +126,46 @@ absl::StatusOr> StructType::FindField( return absl::visit(FindFieldVisitor{*this, type_manager}, id.data_); } +absl::StatusOr StructType::FieldIterator::NextId( + TypeManager& type_manager) { + CEL_ASSIGN_OR_RETURN(auto name, NextName(type_manager)); + return FieldId(name); +} + +absl::StatusOr StructType::FieldIterator::NextName( + TypeManager& type_manager) { + CEL_ASSIGN_OR_RETURN(auto field, Next(type_manager)); + return field.name; +} + +absl::StatusOr StructType::FieldIterator::NextNumber( + TypeManager& type_manager) { + CEL_ASSIGN_OR_RETURN(auto field, Next(type_manager)); + return field.number; +} + +absl::StatusOr> StructType::FieldIterator::NextType( + TypeManager& type_manager) { + CEL_ASSIGN_OR_RETURN(auto field, Next(type_manager)); + return std::move(field.type); +} + namespace base_internal { absl::string_view LegacyStructType::name() const { return MessageTypeName(msg_); } +size_t LegacyStructType::field_count() const { + return MessageTypeFieldCount(msg_); +} + +absl::StatusOr> +LegacyStructType::NewFieldIterator(MemoryManager& memory_manager) const { + return absl::UnimplementedError( + "StructType::NewFieldIterator is not supported by legacy struct types"); +} + // Always returns an error. absl::StatusOr> LegacyStructType::FindFieldByName(TypeManager& type_manager, diff --git a/base/types/struct_type.h b/base/types/struct_type.h index f76fd22ad..1b7b4dddb 100644 --- a/base/types/struct_type.h +++ b/base/types/struct_type.h @@ -28,6 +28,7 @@ #include "absl/types/variant.h" #include "base/internal/data.h" #include "base/kind.h" +#include "base/memory.h" #include "base/type.h" #include "internal/rtti.h" @@ -61,6 +62,24 @@ class StructType : public Type { FieldId(const FieldId&) = default; FieldId& operator=(const FieldId&) = default; + std::string DebugString() const; + + friend bool operator==(const FieldId& lhs, const FieldId& rhs) { + return lhs.data_ == rhs.data_; + } + + friend bool operator<(const FieldId& lhs, const FieldId& rhs); + + template + friend H AbslHashValue(H state, const FieldId& id) { + return H::combine(std::move(state), id.data_); + } + + template + friend void AbslStringify(S& sink, const FieldId& id) { + sink.Append(id.DebugString()); + } + private: friend class StructType; friend class StructValue; @@ -81,24 +100,34 @@ class StructType : public Type { Kind kind() const { return kKind; } - absl::string_view name() const; + absl::string_view name() const ABSL_ATTRIBUTE_LIFETIME_BOUND; std::string DebugString() const; + size_t field_count() const; + // Find the field definition for the given identifier. If the field does // not exist, an OK status and empty optional is returned. If the field // exists, an OK status and the field is returned. Otherwise an error is // returned. absl::StatusOr> FindField(TypeManager& type_manager, - FieldId id) const; + FieldId id) const + ABSL_ATTRIBUTE_LIFETIME_BOUND; // Called by FindField. absl::StatusOr> FindFieldByName( - TypeManager& type_manager, absl::string_view name) const; + TypeManager& type_manager, + absl::string_view name) const ABSL_ATTRIBUTE_LIFETIME_BOUND; // Called by FindField. absl::StatusOr> FindFieldByNumber( - TypeManager& type_manager, int64_t number) const; + TypeManager& type_manager, + int64_t number) const ABSL_ATTRIBUTE_LIFETIME_BOUND; + + class FieldIterator; + + absl::StatusOr> NewFieldIterator( + MemoryManager& memory_manager) const ABSL_ATTRIBUTE_LIFETIME_BOUND; private: friend internal::TypeInfo base_internal::GetStructTypeTypeId( @@ -119,6 +148,44 @@ class StructType : public Type { internal::TypeInfo TypeId() const; }; +// Field describes a single field in a struct. All fields are valid so long as +// StructType is valid, except Field::type which is managed. +struct StructType::Field final { + explicit Field(absl::string_view name, int64_t number, Handle type, + const void* hint = nullptr) + : name(name), number(number), type(std::move(type)), hint(hint) {} + + // The field name. + absl::string_view name; + // The field number. + int64_t number; + // The field type; + Handle type; + // Some implementation-specific data that can be laundered to the value + // implementation for this type to enable potential optimizations. + const void* hint = nullptr; +}; + +class StructType::FieldIterator { + public: + using FieldId = StructType::FieldId; + using Field = StructType::Field; + + virtual ~FieldIterator() = default; + + ABSL_MUST_USE_RESULT virtual bool HasNext() = 0; + + virtual absl::StatusOr Next(TypeManager& type_manager) = 0; + + virtual absl::StatusOr NextId(TypeManager& type_manager); + + virtual absl::StatusOr NextName(TypeManager& type_manager); + + virtual absl::StatusOr NextNumber(TypeManager& type_manager); + + virtual absl::StatusOr> NextType(TypeManager& type_manager); +}; + namespace base_internal { // In an ideal world we would just make StructType a heap type. Unfortunately we @@ -127,6 +194,7 @@ namespace base_internal { // variant. ABSL_ATTRIBUTE_WEAK absl::string_view MessageTypeName(uintptr_t msg); +ABSL_ATTRIBUTE_WEAK size_t MessageTypeFieldCount(uintptr_t msg); class LegacyStructType final : public StructType, public base_internal::InlineData { @@ -144,19 +212,25 @@ class LegacyStructType final : public StructType, return static_cast(type); } - absl::string_view name() const; + absl::string_view name() const ABSL_ATTRIBUTE_LIFETIME_BOUND; // Always returns the same string. std::string DebugString() const { return std::string(name()); } - protected: + size_t field_count() const; + // Always returns an error. absl::StatusOr> FindFieldByName( - TypeManager& type_manager, absl::string_view name) const; + TypeManager& type_manager, + absl::string_view name) const ABSL_ATTRIBUTE_LIFETIME_BOUND; // Always returns an error. absl::StatusOr> FindFieldByNumber( - TypeManager& type_manager, int64_t number) const; + TypeManager& type_manager, + int64_t number) const ABSL_ATTRIBUTE_LIFETIME_BOUND; + + absl::StatusOr> NewFieldIterator( + MemoryManager& memory_manager) const ABSL_ATTRIBUTE_LIFETIME_BOUND; private: static constexpr uintptr_t kMetadata = @@ -196,20 +270,27 @@ class AbstractStructType : public StructType, public base_internal::HeapData { return static_cast(type); } - virtual absl::string_view name() const = 0; + virtual absl::string_view name() const ABSL_ATTRIBUTE_LIFETIME_BOUND = 0; virtual std::string DebugString() const { return std::string(name()); } - protected: - AbstractStructType(); + virtual size_t field_count() const = 0; // Called by FindField. virtual absl::StatusOr> FindFieldByName( - TypeManager& type_manager, absl::string_view name) const = 0; + TypeManager& type_manager, + absl::string_view name) const ABSL_ATTRIBUTE_LIFETIME_BOUND = 0; // Called by FindField. virtual absl::StatusOr> FindFieldByNumber( - TypeManager& type_manager, int64_t number) const = 0; + TypeManager& type_manager, + int64_t number) const ABSL_ATTRIBUTE_LIFETIME_BOUND = 0; + + virtual absl::StatusOr> NewFieldIterator( + MemoryManager& memory_manager) const ABSL_ATTRIBUTE_LIFETIME_BOUND = 0; + + protected: + AbstractStructType(); private: friend internal::TypeInfo base_internal::GetStructTypeTypeId( @@ -258,22 +339,6 @@ class AbstractStructType : public StructType, public base_internal::HeapData { #define CEL_IMPLEMENT_STRUCT_TYPE(struct_type) \ CEL_INTERNAL_IMPLEMENT_TYPE(Struct, struct_type) -struct StructType::Field final { - explicit Field(absl::string_view name, int64_t number, Handle type, - const void* hint = nullptr) - : name(name), number(number), type(std::move(type)), hint(hint) {} - - // The field name. - absl::string_view name; - // The field number. - int64_t number; - // The field type; - Handle type; - // Some implementation-specific data that can be laundered to the value - // implementation for this type to enable potential optimizations. - const void* hint = nullptr; -}; - CEL_INTERNAL_TYPE_DECL(StructType); namespace base_internal { diff --git a/base/value_test.cc b/base/value_test.cc index 569daac91..2c37b7e38 100644 --- a/base/value_test.cc +++ b/base/value_test.cc @@ -196,6 +196,14 @@ class TestStructValue final : public CEL_STRUCT_VALUE_CLASS { } } + size_t field_count() const override { return 4; } + + absl::StatusOr> NewFieldIterator( + MemoryManager& memory_manager) const override { + return absl::UnimplementedError( + "StructValue::NewFieldIterator() is unimplemented"); + } + private: TestStruct value_; @@ -208,6 +216,14 @@ class TestStructType final : public CEL_STRUCT_TYPE_CLASS { public: absl::string_view name() const override { return "test_struct.TestStruct"; } + size_t field_count() const override { return 4; } + + absl::StatusOr> NewFieldIterator( + MemoryManager& memory_manager) const override { + return absl::UnimplementedError( + "StructType::NewFieldIterator() is unimplemented"); + } + protected: absl::StatusOr> FindFieldByName( TypeManager& type_manager, absl::string_view name) const override { diff --git a/base/values/struct_value.cc b/base/values/struct_value.cc index ba6b16a4f..a41693004 100644 --- a/base/values/struct_value.cc +++ b/base/values/struct_value.cc @@ -17,8 +17,10 @@ #include #include #include +#include #include "absl/base/macros.h" +#include "absl/base/optimization.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/variant.h" @@ -28,6 +30,7 @@ #include "base/types/struct_type.h" #include "base/value.h" #include "internal/rtti.h" +#include "internal/status_macros.h" namespace cel { @@ -44,6 +47,10 @@ Handle StructValue::type() const { return CEL_INTERNAL_STRUCT_VALUE_DISPATCH(type); } +size_t StructValue::field_count() const { + return CEL_INTERNAL_STRUCT_VALUE_DISPATCH(field_count); +} + std::string StructValue::DebugString() const { return CEL_INTERNAL_STRUCT_VALUE_DISPATCH(DebugString); } @@ -68,6 +75,11 @@ absl::StatusOr StructValue::HasFieldByNumber( return CEL_INTERNAL_STRUCT_VALUE_DISPATCH(HasFieldByNumber, context, number); } +absl::StatusOr> +StructValue::NewFieldIterator(MemoryManager& memory_manager) const { + return CEL_INTERNAL_STRUCT_VALUE_DISPATCH(NewFieldIterator, memory_manager); +} + internal::TypeInfo StructValue::TypeId() const { return CEL_INTERNAL_STRUCT_VALUE_DISPATCH(TypeId); } @@ -110,8 +122,66 @@ absl::StatusOr StructValue::HasField(const HasFieldContext& context, return absl::visit(HasFieldVisitor{*this, context}, field.data_); } +absl::StatusOr StructValue::FieldIterator::NextId( + const StructValue::GetFieldContext& context) { + CEL_ASSIGN_OR_RETURN(auto entry, Next(context)); + return entry.id; +} + +absl::StatusOr> StructValue::FieldIterator::NextValue( + const StructValue::GetFieldContext& context) { + CEL_ASSIGN_OR_RETURN(auto entry, Next(context)); + return std::move(entry.value); +} + namespace base_internal { +namespace { + +class LegacyStructValueFieldIterator final : public StructValue::FieldIterator { + public: + LegacyStructValueFieldIterator(uintptr_t msg, uintptr_t type_info) + : msg_(msg), + type_info_(type_info), + field_names_(MessageValueListFields(msg_, type_info_)) {} + + bool HasNext() override { return index_ < field_names_.size(); } + + absl::StatusOr Next( + const StructValue::GetFieldContext& context) override { + if (ABSL_PREDICT_FALSE(index_ >= field_names_.size())) { + return absl::FailedPreconditionError( + "StructValue::FieldIterator::Next() called when " + "StructValue::FieldIterator::HasNext() returns false"); + } + const auto& field_name = field_names_[index_]; + CEL_ASSIGN_OR_RETURN( + auto value, MessageValueGetFieldByName( + msg_, type_info_, context.value_factory(), field_name, + context.unbox_null_wrapper_types())); + ++index_; + return Field(StructValue::FieldId(field_name), std::move(value)); + } + + absl::StatusOr NextId( + const StructValue::GetFieldContext& context) override { + if (ABSL_PREDICT_FALSE(index_ >= field_names_.size())) { + return absl::FailedPreconditionError( + "StructValue::FieldIterator::Next() called when " + "StructValue::FieldIterator::HasNext() returns false"); + } + return StructValue::FieldId(field_names_[index_++]); + } + + private: + const uintptr_t msg_; + const uintptr_t type_info_; + const std::vector field_names_; + size_t index_ = 0; +}; + +} // namespace + Handle LegacyStructValue::type() const { if ((msg_ & kMessageWrapperTagMask) == kMessageWrapperTagMask) { // google::protobuf::Message @@ -121,6 +191,10 @@ Handle LegacyStructValue::type() const { return HandleFactory::Make(type_info_); } +size_t LegacyStructValue::field_count() const { + return MessageValueFieldCount(msg_, type_info_); +} + std::string LegacyStructValue::DebugString() const { return type()->DebugString(); } @@ -148,6 +222,12 @@ absl::StatusOr LegacyStructValue::HasFieldByNumber( return MessageValueHasFieldByNumber(msg_, type_info_, number); } +absl::StatusOr> +LegacyStructValue::NewFieldIterator(MemoryManager& memory_manager) const { + return MakeUnique(memory_manager, msg_, + type_info_); +} + AbstractStructValue::AbstractStructValue(Handle type) : StructValue(), base_internal::HeapData(kKind), type_(std::move(type)) { // Ensure `Value*` and `base_internal::HeapData*` are not thunked. diff --git a/base/values/struct_value.h b/base/values/struct_value.h index 596b5e0af..31f56f9af 100644 --- a/base/values/struct_value.h +++ b/base/values/struct_value.h @@ -18,6 +18,8 @@ #include #include #include +#include +#include #include "absl/base/attributes.h" #include "absl/hash/hash.h" @@ -27,6 +29,7 @@ #include "base/handle.h" #include "base/internal/data.h" #include "base/kind.h" +#include "base/memory.h" #include "base/owner.h" #include "base/type.h" #include "base/types/struct_type.h" @@ -62,6 +65,8 @@ class StructValue : public Value { Handle type() const; + size_t field_count() const; + std::string DebugString() const; class GetFieldContext final { @@ -113,6 +118,18 @@ class StructValue : public Value { absl::StatusOr HasFieldByNumber(const HasFieldContext& context, int64_t number) const; + class FieldIterator; + + absl::StatusOr> NewFieldIterator( + MemoryManager& memory_manager) const ABSL_ATTRIBUTE_LIFETIME_BOUND; + + struct Field final { + Field(FieldId id, Handle value) : id(id), value(std::move(value)) {} + + FieldId id; + Handle value; + }; + private: struct GetFieldVisitor; struct HasFieldVisitor; @@ -131,6 +148,25 @@ class StructValue : public Value { internal::TypeInfo TypeId() const; }; +class StructValue::FieldIterator { + public: + using Field = StructValue::Field; + using FieldId = StructValue::FieldId; + + virtual ~FieldIterator() = default; + + ABSL_MUST_USE_RESULT virtual bool HasNext() = 0; + + virtual absl::StatusOr Next( + const StructValue::GetFieldContext& context) = 0; + + virtual absl::StatusOr NextId( + const StructValue::GetFieldContext& context); + + virtual absl::StatusOr> NextValue( + const StructValue::GetFieldContext& context); +}; + CEL_INTERNAL_VALUE_DECL(StructValue); namespace base_internal { @@ -145,6 +181,10 @@ ABSL_ATTRIBUTE_WEAK void MessageValueHash(uintptr_t msg, uintptr_t type_info, ABSL_ATTRIBUTE_WEAK bool MessageValueEquals(uintptr_t lhs_msg, uintptr_t lhs_type_info, const Value& rhs); +ABSL_ATTRIBUTE_WEAK size_t MessageValueFieldCount(uintptr_t msg, + uintptr_t type_info); +ABSL_ATTRIBUTE_WEAK std::vector MessageValueListFields( + uintptr_t msg, uintptr_t type_info); ABSL_ATTRIBUTE_WEAK absl::StatusOr MessageValueHasFieldByNumber( uintptr_t msg, uintptr_t type_info, int64_t number); ABSL_ATTRIBUTE_WEAK absl::StatusOr MessageValueHasFieldByName( @@ -175,6 +215,8 @@ class LegacyStructValue final : public StructValue, public InlineData { std::string DebugString() const; + size_t field_count() const; + absl::StatusOr> GetFieldByName(const GetFieldContext& context, absl::string_view name) const; @@ -187,6 +229,9 @@ class LegacyStructValue final : public StructValue, public InlineData { absl::StatusOr HasFieldByNumber(const HasFieldContext& context, int64_t number) const; + absl::StatusOr> NewFieldIterator( + MemoryManager& memory_manager) const ABSL_ATTRIBUTE_LIFETIME_BOUND; + private: struct GetFieldVisitor; struct HasFieldVisitor; @@ -247,6 +292,8 @@ class AbstractStructValue : public StructValue, const Handle& type() const { return type_; } + virtual size_t field_count() const = 0; + virtual std::string DebugString() const = 0; virtual absl::StatusOr> GetFieldByName( @@ -261,6 +308,9 @@ class AbstractStructValue : public StructValue, virtual absl::StatusOr HasFieldByNumber(const HasFieldContext& context, int64_t number) const = 0; + virtual absl::StatusOr> NewFieldIterator( + MemoryManager& memory_manager) const ABSL_ATTRIBUTE_LIFETIME_BOUND = 0; + protected: explicit AbstractStructValue(Handle type); diff --git a/eval/eval/select_step_test.cc b/eval/eval/select_step_test.cc index 93eb5afe3..be39db78b 100644 --- a/eval/eval/select_step_test.cc +++ b/eval/eval/select_step_test.cc @@ -63,6 +63,8 @@ class MockAccessor : public LegacyTypeAccessApis, public LegacyTypeInfoApis { (const CelValue::MessageWrapper& instance), (const override)); MOCK_METHOD(std::string, DebugString, (const CelValue::MessageWrapper& instance), (const override)); + MOCK_METHOD(std::vector, ListFields, + (const CelValue::MessageWrapper& value), (const override)); const LegacyTypeAccessApis* GetAccessApis( const CelValue::MessageWrapper& instance) const override { return this; diff --git a/eval/internal/interop.cc b/eval/internal/interop.cc index e0879e27a..a126c322e 100644 --- a/eval/internal/interop.cc +++ b/eval/internal/interop.cc @@ -701,6 +701,27 @@ bool MessageValueEquals(uintptr_t lhs_msg, uintptr_t lhs_type_info, static_cast(rhs))); } +size_t MessageValueFieldCount(uintptr_t msg, uintptr_t type_info) { + auto message_wrapper = MessageWrapperAccess::Make(msg, type_info); + if (message_wrapper.message_ptr() == nullptr) { + return 0; + } + const LegacyTypeAccessApis* access_api = + message_wrapper.legacy_type_info()->GetAccessApis(message_wrapper); + return access_api->ListFields(message_wrapper).size(); +} + +std::vector MessageValueListFields(uintptr_t msg, + uintptr_t type_info) { + auto message_wrapper = MessageWrapperAccess::Make(msg, type_info); + if (message_wrapper.message_ptr() == nullptr) { + return std::vector{}; + } + const LegacyTypeAccessApis* access_api = + message_wrapper.legacy_type_info()->GetAccessApis(message_wrapper); + return access_api->ListFields(message_wrapper); +} + absl::StatusOr MessageValueHasFieldByNumber(uintptr_t msg, uintptr_t type_info, int64_t number) { diff --git a/eval/internal/interop_test.cc b/eval/internal/interop_test.cc index d5bf3ae55..3f65416c3 100644 --- a/eval/internal/interop_test.cc +++ b/eval/internal/interop_test.cc @@ -830,6 +830,62 @@ TEST(ValueInterop, LegacyStructEquality) { EXPECT_EQ(lhs_value, rhs_value); } +TEST(ValueInterop, LegacyStructNewFieldIteratorIds) { + google::protobuf::Arena arena; + extensions::ProtoMemoryManager memory_manager(&arena); + TypeFactory type_factory(memory_manager); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); + google::protobuf::Api api; + api.set_name("foo"); + api.set_version("bar"); + ASSERT_OK_AND_ASSIGN( + auto value, + FromLegacyValue(&arena, CelProtoWrapper::CreateMessage(&api, &arena))); + EXPECT_EQ(value->As().field_count(), 2); + ASSERT_OK_AND_ASSIGN( + auto iterator, value->As().NewFieldIterator(memory_manager)); + std::set actual_ids; + while (iterator->HasNext()) { + ASSERT_OK_AND_ASSIGN( + auto id, iterator->NextId(StructValue::GetFieldContext(value_factory))); + actual_ids.insert(id); + } + EXPECT_THAT(iterator->NextId(StructValue::GetFieldContext(value_factory)), + StatusIs(absl::StatusCode::kFailedPrecondition)); + std::set expected_ids = {StructType::FieldId("name"), + StructType::FieldId("version")}; + EXPECT_EQ(actual_ids, expected_ids); +} + +TEST(ValueInterop, LegacyStructNewFieldIteratorValues) { + google::protobuf::Arena arena; + extensions::ProtoMemoryManager memory_manager(&arena); + TypeFactory type_factory(memory_manager); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); + google::protobuf::Api api; + api.set_name("foo"); + api.set_version("bar"); + ASSERT_OK_AND_ASSIGN( + auto value, + FromLegacyValue(&arena, CelProtoWrapper::CreateMessage(&api, &arena))); + EXPECT_EQ(value->As().field_count(), 2); + ASSERT_OK_AND_ASSIGN( + auto iterator, value->As().NewFieldIterator(memory_manager)); + std::set actual_values; + while (iterator->HasNext()) { + ASSERT_OK_AND_ASSIGN( + auto value, + iterator->NextValue(StructValue::GetFieldContext(value_factory))); + actual_values.insert(value->As().ToString()); + } + EXPECT_THAT(iterator->NextId(StructValue::GetFieldContext(value_factory)), + StatusIs(absl::StatusCode::kFailedPrecondition)); + std::set expected_values = {"bar", "foo"}; + EXPECT_EQ(actual_values, expected_values); +} + TEST(ValueInterop, UnknownFromLegacy) { AttributeSet attributes({Attribute("foo")}); FunctionResultSet function_results( diff --git a/eval/public/portable_cel_expr_builder_factory_test.cc b/eval/public/portable_cel_expr_builder_factory_test.cc index 9d1857043..cd742f69f 100644 --- a/eval/public/portable_cel_expr_builder_factory_test.cc +++ b/eval/public/portable_cel_expr_builder_factory_test.cc @@ -18,6 +18,7 @@ #include #include #include +#include #include "google/protobuf/duration.pb.h" #include "google/protobuf/timestamp.pb.h" @@ -309,6 +310,16 @@ class DemoTestMessage : public LegacyTypeMutationApis, ProtoWrapperTypeOptions unboxing_option, cel::MemoryManager& memory_manager) const override; + std::vector ListFields( + const CelValue::MessageWrapper& instance) const override { + std::vector fields; + fields.reserve(fields_.size()); + for (const auto& field : fields_) { + fields.emplace_back(field.first); + } + return fields; + } + private: using Field = ProtoField; const DemoTypeProvider& owning_provider_; diff --git a/eval/public/structs/legacy_type_adapter.h b/eval/public/structs/legacy_type_adapter.h index 6eee0941e..a21e6c795 100644 --- a/eval/public/structs/legacy_type_adapter.h +++ b/eval/public/structs/legacy_type_adapter.h @@ -18,6 +18,8 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_LEGACY_TYPE_ADPATER_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_LEGACY_TYPE_ADPATER_H_ +#include + #include "absl/status/status.h" #include "base/memory.h" #include "eval/public/cel_options.h" @@ -89,6 +91,9 @@ class LegacyTypeAccessApis { const CelValue::MessageWrapper&) const { return false; } + + virtual std::vector ListFields( + const CelValue::MessageWrapper& instance) const = 0; }; // Type information about a legacy Struct type. diff --git a/eval/public/structs/legacy_type_adapter_test.cc b/eval/public/structs/legacy_type_adapter_test.cc index 726a32342..81411f0c0 100644 --- a/eval/public/structs/legacy_type_adapter_test.cc +++ b/eval/public/structs/legacy_type_adapter_test.cc @@ -41,6 +41,11 @@ class TestAccessApiImpl : public LegacyTypeAccessApis { cel::MemoryManager& memory_manager) const override { return absl::UnimplementedError("Not implemented"); } + + std::vector ListFields( + const CelValue::MessageWrapper& instance) const override { + return std::vector(); + } }; TEST(LegacyTypeAdapterAccessApis, DefaultAlwaysInequal) { diff --git a/eval/public/structs/proto_message_type_adapter.cc b/eval/public/structs/proto_message_type_adapter.cc index 556108179..d5ecfe25c 100644 --- a/eval/public/structs/proto_message_type_adapter.cc +++ b/eval/public/structs/proto_message_type_adapter.cc @@ -152,6 +152,24 @@ absl::StatusOr GetFieldImpl(const google::protobuf::Message* message, return result; } +std::vector ListFieldsImpl( + const CelValue::MessageWrapper& instance) { + if (instance.message_ptr() == nullptr) { + return std::vector(); + } + const auto* message = + cel::internal::down_cast(instance.message_ptr()); + const auto* reflect = message->GetReflection(); + std::vector fields; + reflect->ListFields(*message, &fields); + std::vector field_names; + field_names.reserve(fields.size()); + for (const auto* field : fields) { + field_names.emplace_back(field->name()); + } + return field_names; +} + class DucktypedMessageAdapter : public LegacyTypeAccessApis, public LegacyTypeMutationApis, public LegacyTypeInfoApis { @@ -255,6 +273,11 @@ class DucktypedMessageAdapter : public LegacyTypeAccessApis, .SetField(field_name, value, memory_manager, instance); } + std::vector ListFields( + const CelValue::MessageWrapper& instance) const override { + return ListFieldsImpl(instance); + } + const LegacyTypeAccessApis* GetAccessApis( const MessageWrapper& wrapped_message) const override { return this; @@ -431,6 +454,11 @@ bool ProtoMessageTypeAdapter::IsEqualTo( return ProtoEquals(**lhs, **rhs); } +std::vector ProtoMessageTypeAdapter::ListFields( + const CelValue::MessageWrapper& instance) const { + return ListFieldsImpl(instance); +} + const LegacyTypeInfoApis& GetGenericProtoTypeInfoInstance() { return DucktypedMessageAdapter::GetSingleton(); } diff --git a/eval/public/structs/proto_message_type_adapter.h b/eval/public/structs/proto_message_type_adapter.h index 3447da20b..12ba4ae0e 100644 --- a/eval/public/structs/proto_message_type_adapter.h +++ b/eval/public/structs/proto_message_type_adapter.h @@ -15,6 +15,8 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_PROTO_MESSAGE_TYPE_ADAPTER_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_PROTO_MESSAGE_TYPE_ADAPTER_H_ +#include + #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" #include "absl/status/status.h" @@ -62,6 +64,9 @@ class ProtoMessageTypeAdapter : public LegacyTypeAccessApis, bool IsEqualTo(const CelValue::MessageWrapper& instance, const CelValue::MessageWrapper& other_instance) const override; + std::vector ListFields( + const CelValue::MessageWrapper& instance) const override; + private: // Helper for standardizing error messages for SetField operation. absl::Status ValidateSetFieldOp(bool assertion, absl::string_view field, diff --git a/extensions/protobuf/BUILD b/extensions/protobuf/BUILD index 3d75e5532..d4e12702d 100644 --- a/extensions/protobuf/BUILD +++ b/extensions/protobuf/BUILD @@ -125,6 +125,7 @@ cc_test( "//base/testing:type_matchers", "//extensions/protobuf/internal:testing", "//internal:testing", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", "@com_google_cel_spec//proto/test/v1/proto3:test_all_types_cc_proto", "@com_google_protobuf//:protobuf", diff --git a/extensions/protobuf/struct_type.cc b/extensions/protobuf/struct_type.cc index 00cc43a87..8adc6c12a 100644 --- a/extensions/protobuf/struct_type.cc +++ b/extensions/protobuf/struct_type.cc @@ -123,6 +123,63 @@ absl::StatusOr> FieldDescriptorToType( } // namespace +namespace { + +class ProtoStructTypeFieldIterator final : public StructType::FieldIterator { + public: + explicit ProtoStructTypeFieldIterator(const google::protobuf::Descriptor& descriptor) + : descriptor_(descriptor) {} + + bool HasNext() override { return index_ < descriptor_.field_count(); } + + absl::StatusOr Next(TypeManager& type_manager) override { + if (ABSL_PREDICT_FALSE(index_ >= descriptor_.field_count())) { + return absl::FailedPreconditionError( + "StructType::FieldIterator::Next() called when " + "StructType::FieldIterator::HasNext() returns false"); + } + const auto* field = descriptor_.field(index_); + CEL_ASSIGN_OR_RETURN(auto type, FieldDescriptorToType(type_manager, field)); + ++index_; + return StructType::Field(field->name(), field->number(), std::move(type), + field); + } + + absl::StatusOr NextName( + TypeManager& type_manager) override { + if (ABSL_PREDICT_FALSE(index_ >= descriptor_.field_count())) { + return absl::FailedPreconditionError( + "StructType::FieldIterator::Next() called when " + "StructType::FieldIterator::HasNext() returns false"); + } + return descriptor_.field(index_++)->name(); + } + + absl::StatusOr NextNumber(TypeManager& type_manager) override { + if (ABSL_PREDICT_FALSE(index_ >= descriptor_.field_count())) { + return absl::FailedPreconditionError( + "StructType::FieldIterator::Next() called when " + "StructType::FieldIterator::HasNext() returns false"); + } + return descriptor_.field(index_++)->number(); + } + + private: + const google::protobuf::Descriptor& descriptor_; + int index_ = 0; +}; + +} // namespace + +size_t ProtoStructType::field_count() const { + return descriptor().field_count(); +} + +absl::StatusOr> +ProtoStructType::NewFieldIterator(MemoryManager& memory_manager) const { + return MakeUnique(memory_manager, descriptor()); +} + absl::StatusOr> ProtoStructType::FindFieldByName(TypeManager& type_manager, absl::string_view name) const { diff --git a/extensions/protobuf/struct_type.h b/extensions/protobuf/struct_type.h index b5bc795d3..e4a95a41e 100644 --- a/extensions/protobuf/struct_type.h +++ b/extensions/protobuf/struct_type.h @@ -55,9 +55,11 @@ class ProtoStructType final : public CEL_STRUCT_TYPE_CLASS { absl::string_view name() const override { return descriptor().full_name(); } - const google::protobuf::Descriptor& descriptor() const { return *descriptor_; } + size_t field_count() const override; + + absl::StatusOr> NewFieldIterator( + MemoryManager& memory_manager) const override; - protected: // Called by FindField. absl::StatusOr> FindFieldByName( TypeManager& type_manager, absl::string_view name) const override; @@ -66,6 +68,8 @@ class ProtoStructType final : public CEL_STRUCT_TYPE_CLASS { absl::StatusOr> FindFieldByNumber( TypeManager& type_manager, int64_t number) const override; + const google::protobuf::Descriptor& descriptor() const { return *descriptor_; } + private: friend class ProtoType; friend class ProtoTypeProvider; diff --git a/extensions/protobuf/struct_type_test.cc b/extensions/protobuf/struct_type_test.cc index 9dbbdd2a5..27f82c4dd 100644 --- a/extensions/protobuf/struct_type_test.cc +++ b/extensions/protobuf/struct_type_test.cc @@ -14,12 +14,19 @@ #include "extensions/protobuf/struct_type.h" +#include +#include + #include "google/protobuf/type.pb.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "base/internal/memory_manager_testing.h" #include "base/memory.h" #include "base/type_factory.h" #include "base/type_manager.h" #include "base/types/list_type.h" #include "base/types/map_type.h" +#include "extensions/protobuf/internal/testing.h" #include "extensions/protobuf/type.h" #include "extensions/protobuf/type_provider.h" #include "internal/testing.h" @@ -28,10 +35,14 @@ namespace cel::extensions { namespace { +using cel::internal::StatusIs; + using TestAllTypes = google::api::expr::test::v1::proto3::TestAllTypes; -TEST(ProtoStructType, CreateStatically) { - TypeFactory type_factory(MemoryManager::Global()); +using ProtoStructTypeTest = ProtoTest<>; + +TEST_P(ProtoStructTypeTest, CreateStatically) { + TypeFactory type_factory(memory_manager()); ProtoTypeProvider type_provider; TypeManager type_manager(type_factory, type_provider); ASSERT_OK_AND_ASSIGN( @@ -43,8 +54,8 @@ TEST(ProtoStructType, CreateStatically) { EXPECT_EQ(&type->descriptor(), google::protobuf::Field::descriptor()); } -TEST(ProtoStructType, CreateDynamically) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(ProtoStructTypeTest, CreateDynamically) { + TypeFactory type_factory(memory_manager()); ProtoTypeProvider type_provider; TypeManager type_manager(type_factory, type_provider); ASSERT_OK_AND_ASSIGN( @@ -58,8 +69,8 @@ TEST(ProtoStructType, CreateDynamically) { google::protobuf::Field::descriptor()); } -TEST(ProtoStructType, FindFieldByName) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(ProtoStructTypeTest, FindFieldByName) { + TypeFactory type_factory(memory_manager()); ProtoTypeProvider type_provider; TypeManager type_manager(type_factory, type_provider); ASSERT_OK_AND_ASSIGN( @@ -73,8 +84,8 @@ TEST(ProtoStructType, FindFieldByName) { EXPECT_EQ(field->type, type_factory.GetStringType()); } -TEST(ProtoStructType, FindFieldByNumber) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(ProtoStructTypeTest, FindFieldByNumber) { + TypeFactory type_factory(memory_manager()); ProtoTypeProvider type_provider; TypeManager type_manager(type_factory, type_provider); ASSERT_OK_AND_ASSIGN( @@ -87,8 +98,8 @@ TEST(ProtoStructType, FindFieldByNumber) { EXPECT_EQ(field->type, type_factory.GetStringType()); } -TEST(ProtoStructType, EnumField) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(ProtoStructTypeTest, EnumField) { + TypeFactory type_factory(memory_manager()); ProtoTypeProvider type_provider; TypeManager type_manager(type_factory, type_provider); ASSERT_OK_AND_ASSIGN( @@ -101,8 +112,8 @@ TEST(ProtoStructType, EnumField) { EXPECT_EQ(field->type->name(), "google.protobuf.Field.Cardinality"); } -TEST(ProtoStructType, BoolField) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(ProtoStructTypeTest, BoolField) { + TypeFactory type_factory(memory_manager()); ProtoTypeProvider type_provider; TypeManager type_manager(type_factory, type_provider); ASSERT_OK_AND_ASSIGN( @@ -113,8 +124,8 @@ TEST(ProtoStructType, BoolField) { EXPECT_EQ(field->type, type_factory.GetBoolType()); } -TEST(ProtoStructType, IntField) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(ProtoStructTypeTest, IntField) { + TypeFactory type_factory(memory_manager()); ProtoTypeProvider type_provider; TypeManager type_manager(type_factory, type_provider); ASSERT_OK_AND_ASSIGN( @@ -126,8 +137,8 @@ TEST(ProtoStructType, IntField) { EXPECT_EQ(field->type, type_factory.GetIntType()); } -TEST(ProtoStructType, StringListField) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(ProtoStructTypeTest, StringListField) { + TypeFactory type_factory(memory_manager()); ProtoTypeProvider type_provider; TypeManager type_manager(type_factory, type_provider); ASSERT_OK_AND_ASSIGN( @@ -140,8 +151,8 @@ TEST(ProtoStructType, StringListField) { type_factory.GetStringType()); } -TEST(ProtoStructType, StructListField) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(ProtoStructTypeTest, StructListField) { + TypeFactory type_factory(memory_manager()); ProtoTypeProvider type_provider; TypeManager type_manager(type_factory, type_provider); ASSERT_OK_AND_ASSIGN( @@ -155,8 +166,8 @@ TEST(ProtoStructType, StructListField) { "google.protobuf.Option"); } -TEST(ProtoStructType, MapField) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(ProtoStructTypeTest, MapField) { + TypeFactory type_factory(memory_manager()); ProtoTypeProvider type_provider; TypeManager type_manager(type_factory, type_provider); ASSERT_OK_AND_ASSIGN(auto type, @@ -170,5 +181,94 @@ TEST(ProtoStructType, MapField) { EXPECT_EQ(field->type.As()->value(), type_factory.GetStringType()); } +TEST_P(ProtoStructTypeTest, NewFieldIteratorIds) { + TypeFactory type_factory(memory_manager()); + ProtoTypeProvider type_provider; + TypeManager type_manager(type_factory, type_provider); + ASSERT_OK_AND_ASSIGN(auto type, + ProtoType::Resolve(type_manager)); + ASSERT_OK_AND_ASSIGN(auto iterator, type->NewFieldIterator(memory_manager())); + std::set actual_ids; + while (iterator->HasNext()) { + ASSERT_OK_AND_ASSIGN(auto id, iterator->NextId(type_manager)); + actual_ids.insert(id); + } + EXPECT_THAT(iterator->NextId(type_manager), + StatusIs(absl::StatusCode::kFailedPrecondition)); + std::set expected_ids; + const auto* const descriptor = TestAllTypes::descriptor(); + for (int index = 0; index < descriptor->field_count(); ++index) { + expected_ids.insert(StructType::FieldId(descriptor->field(index)->name())); + } + EXPECT_EQ(actual_ids, expected_ids); +} + +TEST_P(ProtoStructTypeTest, NewFieldIteratorName) { + TypeFactory type_factory(memory_manager()); + ProtoTypeProvider type_provider; + TypeManager type_manager(type_factory, type_provider); + ASSERT_OK_AND_ASSIGN(auto type, + ProtoType::Resolve(type_manager)); + ASSERT_OK_AND_ASSIGN(auto iterator, type->NewFieldIterator(memory_manager())); + std::set actual_names; + while (iterator->HasNext()) { + ASSERT_OK_AND_ASSIGN(auto name, iterator->NextName(type_manager)); + actual_names.insert(name); + } + EXPECT_THAT(iterator->NextName(type_manager), + StatusIs(absl::StatusCode::kFailedPrecondition)); + std::set expected_names; + const auto* const descriptor = TestAllTypes::descriptor(); + for (int index = 0; index < descriptor->field_count(); ++index) { + expected_names.insert(descriptor->field(index)->name()); + } + EXPECT_EQ(actual_names, expected_names); +} + +TEST_P(ProtoStructTypeTest, NewFieldIteratorNumbers) { + TypeFactory type_factory(memory_manager()); + ProtoTypeProvider type_provider; + TypeManager type_manager(type_factory, type_provider); + ASSERT_OK_AND_ASSIGN(auto type, + ProtoType::Resolve(type_manager)); + ASSERT_OK_AND_ASSIGN(auto iterator, type->NewFieldIterator(memory_manager())); + std::set actual_numbers; + while (iterator->HasNext()) { + ASSERT_OK_AND_ASSIGN(auto number, iterator->NextNumber(type_manager)); + actual_numbers.insert(number); + } + EXPECT_THAT(iterator->NextNumber(type_manager), + StatusIs(absl::StatusCode::kFailedPrecondition)); + std::set expected_numbers; + const auto* const descriptor = TestAllTypes::descriptor(); + for (int index = 0; index < descriptor->field_count(); ++index) { + expected_numbers.insert(descriptor->field(index)->number()); + } + EXPECT_EQ(actual_numbers, expected_numbers); +} + +TEST_P(ProtoStructTypeTest, NewFieldIteratorTypes) { + TypeFactory type_factory(memory_manager()); + ProtoTypeProvider type_provider; + TypeManager type_manager(type_factory, type_provider); + ASSERT_OK_AND_ASSIGN(auto type, + ProtoType::Resolve(type_manager)); + ASSERT_OK_AND_ASSIGN(auto iterator, type->NewFieldIterator(memory_manager())); + absl::flat_hash_set> actual_types; + while (iterator->HasNext()) { + ASSERT_OK_AND_ASSIGN(auto type, iterator->NextType(type_manager)); + actual_types.insert(std::move(type)); + } + EXPECT_THAT(iterator->NextType(type_manager), + StatusIs(absl::StatusCode::kFailedPrecondition)); + // We cannot really test actual_types, as hand translating TestAllTypes would + // be obnoxious. Otherwise we would simply be testing the same logic against + // itself, which would not be useful. +} + +INSTANTIATE_TEST_SUITE_P(ProtoStructTypeTest, ProtoStructTypeTest, + cel::base_internal::MemoryManagerTestModeAll(), + cel::base_internal::MemoryManagerTestModeTupleName); + } // namespace } // namespace cel::extensions diff --git a/extensions/protobuf/struct_value.cc b/extensions/protobuf/struct_value.cc index 75b524247..e09590bfa 100644 --- a/extensions/protobuf/struct_value.cc +++ b/extensions/protobuf/struct_value.cc @@ -2179,6 +2179,16 @@ std::string ParsedProtoStructValue::DebugString() const { return ParsedProtoStructValue::DebugString(value()); } +size_t ParsedProtoStructValue::field_count() const { + const auto* reflect = value().GetReflection(); + if (ABSL_PREDICT_FALSE(reflect == nullptr)) { + return 0; + } + std::vector fields; + reflect->ListFields(value(), &fields); + return fields.size(); +} + google::protobuf::Message* ParsedProtoStructValue::ValuePointer( google::protobuf::MessageFactory& message_factory, google::protobuf::Arena* arena) const { const auto* desc = value().GetDescriptor(); @@ -2641,6 +2651,60 @@ absl::StatusOr ParsedProtoStructValue::HasField( return reflect->HasField(value(), field_desc); } +class ParsedProtoStructValueFieldIterator final + : public StructValue::FieldIterator { + public: + ParsedProtoStructValueFieldIterator( + const ParsedProtoStructValue* value, + std::vector fields) + : value_(value), fields_(std::move(fields)) {} + + bool HasNext() override { return index_ < fields_.size(); } + + absl::StatusOr Next( + const StructValue::GetFieldContext& context) override { + if (ABSL_PREDICT_FALSE(index_ >= fields_.size())) { + return absl::FailedPreconditionError( + "StructValue::FieldIterator::Next() called when " + "StructValue::FieldIterator::HasNext() returns false"); + } + const auto* field = fields_[index_]; + CEL_ASSIGN_OR_RETURN(auto type, value_->type()->FindFieldByNumber( + context.value_factory().type_manager(), + field->number())); + CEL_ASSIGN_OR_RETURN(auto value, + value_->GetField(context, std::move(type).value())); + ++index_; + return Field(StructValue::FieldId(field->name()), std::move(value)); + } + + absl::StatusOr NextId( + const StructValue::GetFieldContext& context) override { + if (ABSL_PREDICT_FALSE(index_ >= fields_.size())) { + return absl::FailedPreconditionError( + "StructValue::FieldIterator::Next() called when " + "StructValue::FieldIterator::HasNext() returns false"); + } + return StructValue::FieldId(fields_[index_++]->name()); + } + + private: + const ParsedProtoStructValue* const value_; + const std::vector fields_; + size_t index_ = 0; +}; + +absl::StatusOr> +ParsedProtoStructValue::NewFieldIterator(MemoryManager& memory_manager) const { + const auto* reflect = value().GetReflection(); + std::vector fields; + if (ABSL_PREDICT_TRUE(reflect != nullptr)) { + reflect->ListFields(value(), &fields); + } + return MakeUnique(memory_manager, this, + std::move(fields)); +} + } // namespace protobuf_internal } // namespace cel::extensions diff --git a/extensions/protobuf/struct_value.h b/extensions/protobuf/struct_value.h index a081bfcaf..1d748dd64 100644 --- a/extensions/protobuf/struct_value.h +++ b/extensions/protobuf/struct_value.h @@ -174,6 +174,8 @@ class ProtoStructValue : public CEL_STRUCT_VALUE_CLASS { namespace protobuf_internal { +class ParsedProtoStructValueFieldIterator; + // Declare here but implemented in value.cc to give ProtoStructValue access to // the conversion logic in value.cc. Creates a borrowed `ListValue` over // `google.protobuf.ListValue`. @@ -237,6 +239,8 @@ class ParsedProtoStructValue : public ProtoStructValue { std::string DebugString() const final; + size_t field_count() const final; + absl::StatusOr> GetFieldByName( const GetFieldContext& context, absl::string_view name) const final; @@ -249,6 +253,9 @@ class ParsedProtoStructValue : public ProtoStructValue { absl::StatusOr HasFieldByNumber(const HasFieldContext& context, int64_t number) const final; + absl::StatusOr> NewFieldIterator( + MemoryManager& memory_manager) const final; + using ProtoStructValue::value; virtual const google::protobuf::Message& value() const = 0; @@ -280,6 +287,9 @@ class ParsedProtoStructValue : public ProtoStructValue { absl::StatusOr HasField(TypeManager& type_manager, const StructType::Field& field) const; + + private: + friend class ParsedProtoStructValueFieldIterator; }; // Implementation of `ParsedProtoStructValue` which knows the concrete type of diff --git a/extensions/protobuf/struct_value_test.cc b/extensions/protobuf/struct_value_test.cc index a3453363a..7590f0182 100644 --- a/extensions/protobuf/struct_value_test.cc +++ b/extensions/protobuf/struct_value_test.cc @@ -14,7 +14,9 @@ #include "extensions/protobuf/struct_value.h" +#include #include +#include #include "google/protobuf/duration.pb.h" #include "google/protobuf/timestamp.pb.h" @@ -4588,6 +4590,97 @@ TEST_P(ProtoStructValueTest, DynamicRValueDifferentDescriptors) { EXPECT_TRUE(value->Is()); } +TEST_P(ProtoStructValueTest, NewFieldIteratorIds) { + TypeFactory type_factory(memory_manager()); + ProtoTypeProvider type_provider; + TypeManager type_manager(type_factory, type_provider); + ValueFactory value_factory(type_manager); + ASSERT_OK_AND_ASSIGN( + auto value, + ProtoValue::Create(value_factory, + CreateTestMessage([](TestAllTypes& message) { + message.set_single_bool(true); + message.set_single_int32(1); + message.set_single_int64(1); + message.set_single_uint32(1); + message.set_single_uint64(1); + message.set_single_float(1.0); + message.set_single_double(1.0); + message.set_single_bytes("foo"); + message.set_single_string("foo"); + message.set_standalone_enum(TestAllTypes::BAR); + message.mutable_standalone_message()->set_bb(1); + message.mutable_single_duration()->set_seconds(1); + message.mutable_single_timestamp()->set_seconds(1); + }))); + EXPECT_EQ(value->As().field_count(), 13); + ASSERT_OK_AND_ASSIGN(auto iterator, value->As().NewFieldIterator( + memory_manager())); + std::set actual_ids; + while (iterator->HasNext()) { + ASSERT_OK_AND_ASSIGN( + auto id, iterator->NextId(StructValue::GetFieldContext(value_factory))); + actual_ids.insert(id); + } + EXPECT_THAT(iterator->NextId(StructValue::GetFieldContext(value_factory)), + CanonicalStatusIs(absl::StatusCode::kFailedPrecondition)); + std::set expected_ids = { + StructValue::FieldId("single_bool"), + StructValue::FieldId("single_int32"), + StructValue::FieldId("single_int64"), + StructValue::FieldId("single_uint32"), + StructValue::FieldId("single_uint64"), + StructValue::FieldId("single_float"), + StructValue::FieldId("single_double"), + StructValue::FieldId("single_bytes"), + StructValue::FieldId("single_string"), + StructValue::FieldId("standalone_enum"), + StructValue::FieldId("standalone_message"), + StructValue::FieldId("single_duration"), + StructValue::FieldId("single_timestamp")}; + EXPECT_EQ(actual_ids, expected_ids); +} + +TEST_P(ProtoStructValueTest, NewFieldIteratorValues) { + TypeFactory type_factory(memory_manager()); + ProtoTypeProvider type_provider; + TypeManager type_manager(type_factory, type_provider); + ValueFactory value_factory(type_manager); + ASSERT_OK_AND_ASSIGN( + auto value, + ProtoValue::Create(value_factory, + CreateTestMessage([](TestAllTypes& message) { + message.set_single_bool(true); + message.set_single_int32(1); + message.set_single_int64(1); + message.set_single_uint32(1); + message.set_single_uint64(1); + message.set_single_float(1.0); + message.set_single_double(1.0); + message.set_single_bytes("foo"); + message.set_single_string("foo"); + message.set_standalone_enum(TestAllTypes::BAR); + message.mutable_standalone_message()->set_bb(1); + message.mutable_single_duration()->set_seconds(1); + message.mutable_single_timestamp()->set_seconds(1); + }))); + EXPECT_EQ(value->As().field_count(), 13); + ASSERT_OK_AND_ASSIGN(auto iterator, value->As().NewFieldIterator( + memory_manager())); + std::vector> actual_values; + while (iterator->HasNext()) { + ASSERT_OK_AND_ASSIGN( + auto value, + iterator->NextValue(StructValue::GetFieldContext(value_factory))); + actual_values.push_back(std::move(value)); + } + EXPECT_THAT(iterator->NextValue(StructValue::GetFieldContext(value_factory)), + CanonicalStatusIs(absl::StatusCode::kFailedPrecondition)); + // We cannot really test actual_types, as hand translating TestAllTypes would + // be obnoxious. Otherwise we would simply be testing the same logic against + // itself, which would not be useful. +} + INSTANTIATE_TEST_SUITE_P(ProtoStructValueTest, ProtoStructValueTest, cel::base_internal::MemoryManagerTestModeAll(), cel::base_internal::MemoryManagerTestModeTupleName); diff --git a/internal/BUILD b/internal/BUILD index 698db7a30..4cce4b24f 100644 --- a/internal/BUILD +++ b/internal/BUILD @@ -73,6 +73,11 @@ cc_test( ], ) +cc_library( + name = "overloaded", + hdrs = ["overloaded.h"], +) + cc_library( name = "status_macros", hdrs = ["status_macros.h"], diff --git a/internal/overloaded.h b/internal/overloaded.h new file mode 100644 index 000000000..8d317d745 --- /dev/null +++ b/internal/overloaded.h @@ -0,0 +1,30 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_OVERLOADED_H_ +#define THIRD_PARTY_CEL_CPP_INTERNAL_OVERLOADED_H_ + +namespace cel::internal { + +template +struct Overloaded : Ts... { + using Ts::operator()...; +}; + +template +Overloaded(Ts...) -> Overloaded; + +} // namespace cel::internal + +#endif // THIRD_PARTY_CEL_CPP_INTERNAL_OVERLOADED_H_ From a3219f7cf0c30758bb04a1404bfa097209ca1de2 Mon Sep 17 00:00:00 2001 From: jdtatum Date: Wed, 10 May 2023 16:44:40 +0000 Subject: [PATCH 254/303] Add interface for flat expression builder program optimizers. PiperOrigin-RevId: 530933240 --- eval/compiler/BUILD | 4 +- eval/compiler/constant_folding.cc | 46 ++++++- eval/compiler/constant_folding.h | 44 +++---- eval/compiler/constant_folding_test.cc | 116 ++++++++++++++---- eval/compiler/flat_expr_builder.cc | 65 +++++----- eval/compiler/flat_expr_builder.h | 12 +- eval/compiler/flat_expr_builder_extensions.h | 26 +++- eval/public/BUILD | 1 + .../portable_cel_expr_builder_factory.cc | 13 +- 9 files changed, 232 insertions(+), 95 deletions(-) diff --git a/eval/compiler/BUILD b/eval/compiler/BUILD index 61467f95b..f413a278e 100644 --- a/eval/compiler/BUILD +++ b/eval/compiler/BUILD @@ -21,7 +21,6 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], ) @@ -194,6 +193,7 @@ cc_library( "//base:handle", "//base:kind", "//base:value", + "//base/internal:ast_impl", "//eval/eval:const_value_step", "//eval/eval:evaluator_core", "//eval/internal:errors", @@ -212,7 +212,6 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:variant", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -236,7 +235,6 @@ cc_test( "//eval/public:builtin_func_registrar", "//eval/public:cel_function_registry", "//eval/public:cel_type_registry", - "//eval/testutil:test_message_cc_proto", "//extensions/protobuf:ast_converters", "//extensions/protobuf:memory_manager", "//internal:status_macros", diff --git a/eval/compiler/constant_folding.cc b/eval/compiler/constant_folding.cc index 2f30084c1..ceb2a9237 100644 --- a/eval/compiler/constant_folding.cc +++ b/eval/compiler/constant_folding.cc @@ -12,6 +12,7 @@ #include "base/ast_internal.h" #include "base/function.h" #include "base/handle.h" +#include "base/internal/ast_impl.h" #include "base/kind.h" #include "base/value.h" #include "base/values/bytes_value.h" @@ -49,6 +50,7 @@ using ::google::api::expr::runtime::ExecutionFrame; using ::google::api::expr::runtime::ExecutionPath; using ::google::api::expr::runtime::ExecutionPathView; using ::google::api::expr::runtime::PlannerContext; +using ::google::api::expr::runtime::ProgramOptimizer; using ::google::api::expr::runtime::Resolver; using ::google::api::expr::runtime::builtin::kAnd; using ::google::api::expr::runtime::builtin::kOr; @@ -386,7 +388,37 @@ bool ConstantFoldingTransform::Transform(const Expr& expr, Expr& out_) { return absl::visit(handler, expr.expr_kind()); } -} // namespace +class ConstantFoldingExtension : public ProgramOptimizer { + public: + ConstantFoldingExtension(int stack_limit, google::protobuf::Arena* arena) + : arena_(arena), state_(stack_limit, arena) {} + + absl::Status OnInit(google::api::expr::runtime::PlannerContext& context, + const AstImpl& ast) override { + // Clean up const stack incase of failure in the middle of planning previous + // expression. + is_const_.clear(); + return absl::OkStatus(); + } + + absl::Status OnPreVisit(google::api::expr::runtime::PlannerContext& context, + const Expr& node) override; + absl::Status OnPostVisit(google::api::expr::runtime::PlannerContext& context, + const Expr& node) override; + + private: + enum class IsConst { + kConditional, + kNonConst, + }; + + google::protobuf::Arena* arena_; + google::api::expr::runtime::Activation empty_; + google::api::expr::runtime::CelEvaluationListener null_listener_; + google::api::expr::runtime::CelExpressionFlatEvaluationState state_; + + std::vector is_const_; +}; absl::Status ConstantFoldingExtension::OnPreVisit(PlannerContext& context, const Expr& node) { @@ -447,6 +479,10 @@ absl::Status ConstantFoldingExtension::OnPreVisit(PlannerContext& context, absl::Status ConstantFoldingExtension::OnPostVisit(PlannerContext& context, const Expr& node) { + if (is_const_.empty()) { + return absl::InternalError("ConstantFoldingExtension called out of order."); + } + IsConst is_const = is_const_.back(); is_const_.pop_back(); @@ -482,6 +518,8 @@ absl::Status ConstantFoldingExtension::OnPostVisit(PlannerContext& context, return context.ReplaceSubplan(node, std::move(new_plan)); } +} // namespace + void FoldConstants( const Expr& ast, const FunctionRegistry& registry, google::protobuf::Arena* arena, absl::flat_hash_map>& constant_idents, @@ -490,4 +528,10 @@ void FoldConstants( constant_folder.Transform(ast, out_ast); } +std::unique_ptr +CreateConstantFoldingExtension(google::protobuf::Arena* arena, + ConstantFoldingOptions options) { + return std::make_unique(options.stack_limit, arena); +} + } // namespace cel::ast::internal diff --git a/eval/compiler/constant_folding.h b/eval/compiler/constant_folding.h index 65dbbf224..76d5a3f69 100644 --- a/eval/compiler/constant_folding.h +++ b/eval/compiler/constant_folding.h @@ -1,17 +1,13 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_COMPILER_CONSTANT_FOLDING_H_ #define THIRD_PARTY_CEL_CPP_EVAL_COMPILER_CONSTANT_FOLDING_H_ +#include #include -#include #include "absl/container/flat_hash_map.h" -#include "absl/status/status.h" #include "base/ast_internal.h" #include "base/value.h" #include "eval/compiler/flat_expr_builder_extensions.h" -#include "eval/eval/evaluator_core.h" -#include "eval/public/activation.h" -#include "eval/public/cel_expression.h" #include "runtime/function_registry.h" #include "google/protobuf/arena.h" @@ -25,30 +21,24 @@ void FoldConstants( absl::flat_hash_map>& constant_idents, Expr& out_ast); -class ConstantFoldingExtension { - public: - ConstantFoldingExtension(int stack_limit, google::protobuf::Arena* arena) - : arena_(arena), state_(stack_limit, arena) {} - - absl::Status OnPreVisit(google::api::expr::runtime::PlannerContext& context, - const Expr& node); - absl::Status OnPostVisit(google::api::expr::runtime::PlannerContext& context, - const Expr& node); - - private: - enum class IsConst { - kConditional, - kNonConst, - }; - - google::protobuf::Arena* arena_; - google::api::expr::runtime::Activation empty_; - google::api::expr::runtime::CelEvaluationListener null_listener_; - google::api::expr::runtime::CelExpressionFlatEvaluationState state_; - - std::vector is_const_; +struct ConstantFoldingOptions { + // Stack limit for evaluating constant sub expressions. + // Should accommodate the maximum expected number of dependencies for a small + // subexpression (e.g. number of elements in a list). + // + // 64 is sufficient to support map literals with 32 key/value pairs per the + // minimum required support in the CEL spec. + int stack_limit = 64; }; +// Create a new constant folding extension. +// Eagerly evaluates sub expressions with all constant inputs, and replaces said +// sub expression with the result. +std::unique_ptr +CreateConstantFoldingExtension( + google::protobuf::Arena* arena, + ConstantFoldingOptions options = ConstantFoldingOptions()); + } // namespace cel::ast::internal #endif // THIRD_PARTY_CEL_CPP_EVAL_COMPILER_CONSTANT_FOLDING_H_ diff --git a/eval/compiler/constant_folding_test.cc b/eval/compiler/constant_folding_test.cc index c03d6abd9..542750346 100644 --- a/eval/compiler/constant_folding_test.cc +++ b/eval/compiler/constant_folding_test.cc @@ -4,7 +4,6 @@ #include #include "google/api/expr/v1alpha1/syntax.pb.h" -#include "google/protobuf/text_format.h" #include "base/ast_internal.h" #include "base/internal/ast_impl.h" #include "base/type_factory.h" @@ -23,7 +22,6 @@ #include "eval/public/builtin_func_registrar.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_type_registry.h" -#include "eval/testutil/test_message.pb.h" #include "extensions/protobuf/ast_converters.h" #include "extensions/protobuf/memory_manager.h" #include "internal/status_macros.h" @@ -49,9 +47,11 @@ using ::google::api::expr::runtime::CelTypeRegistry; using ::google::api::expr::runtime::CreateConstValueStep; using ::google::api::expr::runtime::ExecutionPath; using ::google::api::expr::runtime::PlannerContext; +using ::google::api::expr::runtime::ProgramOptimizer; using ::google::api::expr::runtime::Resolver; using ::google::protobuf::Arena; using testing::SizeIs; +using cel::internal::StatusIs; class ConstantFoldingTestWithValueFactory : public testing::Test { public: @@ -606,18 +606,20 @@ TEST_F(UpdatedConstantFoldingTest, SkipsTernary) { google::protobuf::Arena arena; constexpr int kStackLimit = 1; - ConstantFoldingExtension constant_folder(kStackLimit, &arena); + std::unique_ptr constant_folder = + CreateConstantFoldingExtension(&arena, {kStackLimit}); // Act // Issue the visitation calls. - ASSERT_OK(constant_folder.OnPreVisit(context, call)); - ASSERT_OK(constant_folder.OnPreVisit(context, condition)); - ASSERT_OK(constant_folder.OnPostVisit(context, condition)); - ASSERT_OK(constant_folder.OnPreVisit(context, true_branch)); - ASSERT_OK(constant_folder.OnPostVisit(context, true_branch)); - ASSERT_OK(constant_folder.OnPreVisit(context, false_branch)); - ASSERT_OK(constant_folder.OnPostVisit(context, false_branch)); - ASSERT_OK(constant_folder.OnPostVisit(context, call)); + ASSERT_OK(constant_folder->OnInit(context, ast_impl)); + ASSERT_OK(constant_folder->OnPreVisit(context, call)); + ASSERT_OK(constant_folder->OnPreVisit(context, condition)); + ASSERT_OK(constant_folder->OnPostVisit(context, condition)); + ASSERT_OK(constant_folder->OnPreVisit(context, true_branch)); + ASSERT_OK(constant_folder->OnPostVisit(context, true_branch)); + ASSERT_OK(constant_folder->OnPreVisit(context, false_branch)); + ASSERT_OK(constant_folder->OnPostVisit(context, false_branch)); + ASSERT_OK(constant_folder->OnPostVisit(context, call)); // Assert // No changes attempted. @@ -670,16 +672,18 @@ TEST_F(UpdatedConstantFoldingTest, SkipsOr) { google::protobuf::Arena arena; constexpr int kStackLimit = 1; - ConstantFoldingExtension constant_folder(kStackLimit, &arena); + std::unique_ptr constant_folder = + CreateConstantFoldingExtension(&arena, {kStackLimit}); // Act // Issue the visitation calls. - ASSERT_OK(constant_folder.OnPreVisit(context, call)); - ASSERT_OK(constant_folder.OnPreVisit(context, left_condition)); - ASSERT_OK(constant_folder.OnPostVisit(context, left_condition)); - ASSERT_OK(constant_folder.OnPreVisit(context, right_condition)); - ASSERT_OK(constant_folder.OnPostVisit(context, right_condition)); - ASSERT_OK(constant_folder.OnPostVisit(context, call)); + ASSERT_OK(constant_folder->OnInit(context, ast_impl)); + ASSERT_OK(constant_folder->OnPreVisit(context, call)); + ASSERT_OK(constant_folder->OnPreVisit(context, left_condition)); + ASSERT_OK(constant_folder->OnPostVisit(context, left_condition)); + ASSERT_OK(constant_folder->OnPreVisit(context, right_condition)); + ASSERT_OK(constant_folder->OnPostVisit(context, right_condition)); + ASSERT_OK(constant_folder->OnPostVisit(context, call)); // Assert // No changes attempted. @@ -732,22 +736,84 @@ TEST_F(UpdatedConstantFoldingTest, SkipsAnd) { google::protobuf::Arena arena; constexpr int kStackLimit = 1; - ConstantFoldingExtension constant_folder(kStackLimit, &arena); + std::unique_ptr constant_folder = + CreateConstantFoldingExtension(&arena, {kStackLimit}); // Act // Issue the visitation calls. - ASSERT_OK(constant_folder.OnPreVisit(context, call)); - ASSERT_OK(constant_folder.OnPreVisit(context, left_condition)); - ASSERT_OK(constant_folder.OnPostVisit(context, left_condition)); - ASSERT_OK(constant_folder.OnPreVisit(context, right_condition)); - ASSERT_OK(constant_folder.OnPostVisit(context, right_condition)); - ASSERT_OK(constant_folder.OnPostVisit(context, call)); + ASSERT_OK(constant_folder->OnInit(context, ast_impl)); + ASSERT_OK(constant_folder->OnPreVisit(context, call)); + ASSERT_OK(constant_folder->OnPreVisit(context, left_condition)); + ASSERT_OK(constant_folder->OnPostVisit(context, left_condition)); + ASSERT_OK(constant_folder->OnPreVisit(context, right_condition)); + ASSERT_OK(constant_folder->OnPostVisit(context, right_condition)); + ASSERT_OK(constant_folder->OnPostVisit(context, call)); // Assert // No changes attempted. EXPECT_THAT(path, SizeIs(3)); } +TEST_F(UpdatedConstantFoldingTest, ErrorsOnUnexpectedOrder) { + // Arrange + ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, + ParseFromCel("true && false")); + AstImpl& ast_impl = AstImpl::CastFromPublicAst(*ast); + + const Expr& call = ast_impl.root_expr(); + const Expr& left_condition = call.call_expr().args()[0]; + const Expr& right_condition = call.call_expr().args()[1]; + + PlannerContext::ProgramTree tree; + PlannerContext::ProgramInfo& call_info = tree[&call]; + call_info.range_start = 0; + call_info.range_len = 4; + call_info.children = {&left_condition, &right_condition}; + + PlannerContext::ProgramInfo& left_condition_info = tree[&left_condition]; + left_condition_info.range_start = 0; + left_condition_info.range_len = 1; + left_condition_info.parent = &call; + + PlannerContext::ProgramInfo& right_condition_info = tree[&right_condition]; + right_condition_info.range_start = 1; + right_condition_info.range_len = 1; + right_condition_info.parent = &call; + + // Mock execution path that has placeholders for the non-shortcircuiting + // version of ternary. + ExecutionPath path; + + ASSERT_OK_AND_ASSIGN(path.emplace_back(), + CreateConstValueStep(Constant(ConstantKind(true)), -1)); + + ASSERT_OK_AND_ASSIGN(path.emplace_back(), + CreateConstValueStep(Constant(ConstantKind(false)), -1)); + + // Just a placeholder. + ASSERT_OK_AND_ASSIGN( + path.emplace_back(), + CreateConstValueStep(Constant(NullValue::kNullValue), -1)); + + PlannerContext context(resolver_, type_registry_, options_, builder_warnings_, + path, tree); + + google::protobuf::Arena arena; + constexpr int kStackLimit = 1; + std::unique_ptr constant_folder = + CreateConstantFoldingExtension(&arena, {kStackLimit}); + + // Act + // Issue the visitation calls in wrong order. + ASSERT_OK(constant_folder->OnPreVisit(context, call)); + ASSERT_OK(constant_folder->OnPreVisit(context, left_condition)); + ASSERT_OK(constant_folder->OnInit(context, ast_impl)); + + // ASSERT + EXPECT_THAT(constant_folder->OnPostVisit(context, left_condition), + StatusIs(absl::StatusCode::kInternal)); +} + } // namespace } // namespace cel::ast::internal diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index 1b49917e1..0170f6e2e 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -270,15 +270,15 @@ class FlatExprVisitor : public cel::ast::internal::AstVisitor { public: FlatExprVisitor( const google::api::expr::runtime::Resolver& resolver, - google::api::expr::runtime::ExecutionPath* path, const cel::RuntimeOptions& options, const absl::flat_hash_map>& constant_idents, - google::protobuf::Arena* constant_arena, bool updated_constant_folding, bool enable_comprehension_vulnerability_check, - google::api::expr::runtime::BuilderWarnings* warnings, bool enable_regex_precompilation, + absl::Span> program_optimizers, const absl::flat_hash_map* reference_map, + google::api::expr::runtime::ExecutionPath* path, + google::api::expr::runtime::BuilderWarnings* warnings, google::protobuf::Arena* arena, PlannerContext::ProgramTree& program_tree, PlannerContext& extension_context) : resolver_(resolver), @@ -288,21 +288,16 @@ class FlatExprVisitor : public cel::ast::internal::AstVisitor { parent_expr_(nullptr), options_(options), constant_idents_(constant_idents), - constant_arena_(constant_arena), enable_comprehension_vulnerability_check_( enable_comprehension_vulnerability_check), - builder_warnings_(warnings), enable_regex_precompilation_(enable_regex_precompilation), + program_optimizers_(program_optimizers), + builder_warnings_(warnings), regex_program_builder_(options_.regex_max_program_size), reference_map_(reference_map), arena_(arena), program_tree_(program_tree), - extension_context_(extension_context) { - if (updated_constant_folding) { - constexpr int kDefaultConstFoldStackLimit = 64; - constant_folding_.emplace(kDefaultConstFoldStackLimit, constant_arena_); - } - } + extension_context_(extension_context) {} void PreVisitExpr(const cel::ast::internal::Expr* expr, const cel::ast::internal::SourcePosition*) override { @@ -312,8 +307,7 @@ class FlatExprVisitor : public cel::ast::internal::AstVisitor { if (!progress_status_.ok()) { return; } - // TODO(issues/5): this will be generalized later. - if (!(constant_folding_.has_value())) { + if (program_optimizers_.empty()) { return; } PlannerContext::ProgramInfo& info = program_tree_[expr]; @@ -323,10 +317,13 @@ class FlatExprVisitor : public cel::ast::internal::AstVisitor { program_tree_[parent_expr_].children.push_back(expr); } parent_expr_ = expr; - absl::Status status = - constant_folding_->OnPreVisit(extension_context_, *expr); - if (!status.ok()) { - SetProgressStatusError(status); + + for (const std::unique_ptr& optimizer : + program_optimizers_) { + absl::Status status = optimizer->OnPreVisit(extension_context_, *expr); + if (!status.ok()) { + SetProgressStatusError(status); + } } } @@ -336,16 +333,19 @@ class FlatExprVisitor : public cel::ast::internal::AstVisitor { return; } // TODO(issues/5): this will be generalized later. - if (!constant_folding_.has_value()) { + if (program_optimizers_.empty()) { return; } PlannerContext::ProgramInfo& info = program_tree_[expr]; info.range_len = execution_path_->size() - info.range_start; parent_expr_ = info.parent; - absl::Status status = - constant_folding_->OnPostVisit(extension_context_, *expr); - if (!status.ok()) { - SetProgressStatusError(status); + + for (const std::unique_ptr& optimizer : + program_optimizers_) { + absl::Status status = optimizer->OnPostVisit(extension_context_, *expr); + if (!status.ok()) { + SetProgressStatusError(status); + } } } @@ -857,16 +857,15 @@ class FlatExprVisitor : public cel::ast::internal::AstVisitor { const cel::RuntimeOptions& options_; const absl::flat_hash_map>& constant_idents_; - google::protobuf::Arena* constant_arena_; - absl::optional - constant_folding_; std::stack comprehension_stack_; bool enable_comprehension_vulnerability_check_; + bool enable_regex_precompilation_; + + absl::Span> program_optimizers_; google::api::expr::runtime::BuilderWarnings* builder_warnings_; - bool enable_regex_precompilation_; RegexProgramBuilder regex_program_builder_; const absl::flat_hash_map* const reference_map_; @@ -1321,7 +1320,7 @@ FlatExprBuilder::CreateExpressionImpl( } cel::ast::internal::Expr const_fold_buffer; - if (constant_folding_ && !updated_constant_folding_) { + if (constant_folding_) { cel::ast::internal::FoldConstants( ast_impl.root_expr(), this->GetRegistry()->InternalGetRegistry(), constant_arena_, constant_idents, const_fold_buffer); @@ -1330,11 +1329,15 @@ FlatExprBuilder::CreateExpressionImpl( auto arena = std::make_unique(); + for (const std::unique_ptr& optimizer : + program_optimizers_) { + CEL_RETURN_IF_ERROR(optimizer->OnInit(extension_context, ast_impl)); + } FlatExprVisitor visitor( - resolver, &execution_path, options_, constant_idents, constant_arena_, - updated_constant_folding_, enable_comprehension_vulnerability_check_, - &warnings_builder, enable_regex_precompilation_, - &ast_impl.reference_map(), arena.get(), program_tree, extension_context); + resolver, options_, constant_idents, + enable_comprehension_vulnerability_check_, enable_regex_precompilation_, + program_optimizers_, &ast_impl.reference_map(), &execution_path, + &warnings_builder, arena.get(), program_tree, extension_context); AstTraverse(effective_expr, &ast_impl.source_info(), &visitor); diff --git a/eval/compiler/flat_expr_builder.h b/eval/compiler/flat_expr_builder.h index f2c934293..8898a15ca 100644 --- a/eval/compiler/flat_expr_builder.h +++ b/eval/compiler/flat_expr_builder.h @@ -44,10 +44,10 @@ class FlatExprBuilder : public CelExpressionBuilder { // Toggle constant folding optimization. By default it is not enabled. // The provided arena is used to hold the generated constants. - void set_constant_folding(bool enabled, google::protobuf::Arena* arena, - bool updated = false) { + // TODO(issues/5): default enable the updated version then deprecate this + // function. + void set_constant_folding(bool enabled, google::protobuf::Arena* arena) { constant_folding_ = enabled; - updated_constant_folding_ = updated; constant_arena_ = arena; } @@ -67,6 +67,10 @@ class FlatExprBuilder : public CelExpressionBuilder { ast_transforms_.push_back(std::move(transform)); } + void AddProgramOptimizer(std::unique_ptr optimizer) { + program_optimizers_.push_back(std::move(optimizer)); + } + void set_enable_regex_precompilation(bool enable) { enable_regex_precompilation_ = enable; } @@ -99,11 +103,11 @@ class FlatExprBuilder : public CelExpressionBuilder { cel::RuntimeOptions options_; std::vector> ast_transforms_; + std::vector> program_optimizers_; bool enable_regex_precompilation_ = false; bool enable_comprehension_vulnerability_check_ = false; bool constant_folding_ = false; - bool updated_constant_folding_ = false; google::protobuf::Arena* constant_arena_ = nullptr; }; diff --git a/eval/compiler/flat_expr_builder_extensions.h b/eval/compiler/flat_expr_builder_extensions.h index 1a1712ed2..3c0eac70a 100644 --- a/eval/compiler/flat_expr_builder_extensions.h +++ b/eval/compiler/flat_expr_builder_extensions.h @@ -84,7 +84,7 @@ class PlannerContext { }; // Interface for Ast Transforms. -// If any are present, the flat expr builder will apply the Ast Transforms in +// If any are present, the FlatExprBuilder will apply the Ast Transforms in // order on a copy of the relevant input expressions before planning the // program. class AstTransform { @@ -95,6 +95,30 @@ class AstTransform { cel::ast::internal::AstImpl& ast) const = 0; }; +// Interface for program optimizers. +// +// If any are present, the FlatExprBuilder will notify the implementations in +// order as it traverses the input ast. +// +// Note: implementations must correctly check that subprograms are available +// before accessing (i.e. they have not already been edited). +class ProgramOptimizer { + public: + virtual ~ProgramOptimizer() = default; + + // Called once before program planning begins for the given AST. + virtual absl::Status OnInit(PlannerContext& context, + const cel::ast::internal::AstImpl& ast) = 0; + + // Called before planning the given expr node. + virtual absl::Status OnPreVisit(PlannerContext& context, + const cel::ast::internal::Expr& node) = 0; + + // Called after planning the given expr node. + virtual absl::Status OnPostVisit(PlannerContext& context, + const cel::ast::internal::Expr& node) = 0; +}; + } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_COMPILER_FLAT_EXPR_BUILDER_EXTENSIONS_H_ diff --git a/eval/public/BUILD b/eval/public/BUILD index 520cd7ccc..eb798643b 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -1197,6 +1197,7 @@ cc_library( deps = [ ":cel_expression", ":cel_options", + "//eval/compiler:constant_folding", "//eval/compiler:flat_expr_builder", "//eval/compiler:qualified_reference_resolver", "//eval/public/structs:legacy_type_provider", diff --git a/eval/public/portable_cel_expr_builder_factory.cc b/eval/public/portable_cel_expr_builder_factory.cc index e8950af42..5d4bb5de7 100644 --- a/eval/public/portable_cel_expr_builder_factory.cc +++ b/eval/public/portable_cel_expr_builder_factory.cc @@ -21,6 +21,7 @@ #include #include "absl/status/status.h" +#include "eval/compiler/constant_folding.h" #include "eval/compiler/flat_expr_builder.h" #include "eval/compiler/qualified_reference_resolver.h" #include "eval/public/cel_options.h" @@ -50,9 +51,15 @@ std::unique_ptr CreatePortableExprBuilder( builder->set_enable_comprehension_vulnerability_check( options.enable_comprehension_vulnerability_check); builder->set_enable_regex_precompilation(options.enable_regex_precompilation); - builder->set_constant_folding(options.constant_folding, - options.constant_arena, - options.enable_updated_constant_folding); + + if (options.constant_folding && options.enable_updated_constant_folding) { + builder->AddProgramOptimizer( + cel::ast::internal::CreateConstantFoldingExtension( + options.constant_arena)); + } else { + builder->set_constant_folding(options.constant_folding, + options.constant_arena); + } return builder; } From a1fbbacf0fc5059fb5b93ed315d1b2f72acb7f5f Mon Sep 17 00:00:00 2001 From: tswadell Date: Wed, 10 May 2023 20:17:47 +0000 Subject: [PATCH 255/303] Internal tool change PiperOrigin-RevId: 530988037 --- eval/testutil/test_message.proto | 1 + 1 file changed, 1 insertion(+) diff --git a/eval/testutil/test_message.proto b/eval/testutil/test_message.proto index bd42dfb6e..3fae9f915 100644 --- a/eval/testutil/test_message.proto +++ b/eval/testutil/test_message.proto @@ -66,6 +66,7 @@ message TestMessage { map int32_float_map = 207; map int64_enum_map = 208; map string_timestamp_map = 209; + map string_message_map = 210; // Well-known types. google.protobuf.Any any_value = 300; From a4f8c6f4d372bb31353f12ef8c6db77e59052c55 Mon Sep 17 00:00:00 2001 From: jcking Date: Wed, 10 May 2023 21:30:17 +0000 Subject: [PATCH 256/303] Fix specialization of `MapValueBuilder::InsertOrUpdate` PiperOrigin-RevId: 531006717 --- base/values/map_value_builder.h | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/base/values/map_value_builder.h b/base/values/map_value_builder.h index 5aa788254..3521e77ca 100644 --- a/base/values/map_value_builder.h +++ b/base/values/map_value_builder.h @@ -1268,23 +1268,19 @@ class MapValueBuilderImpl : public MapValueBuilderInterface { } absl::StatusOr InsertOrUpdate(const K& key, const Handle& value) { - return storage_.insert_or_assign(std::make_pair(key, value)).second; + return storage_.insert_or_assign(key, value).second; } absl::StatusOr InsertOrUpdate(const K& key, Handle&& value) { - return storage_.insert_or_assign(std::make_pair(key, std::move(value))) - .second; + return storage_.insert_or_assign(key, std::move(value)).second; } absl::StatusOr InsertOrUpdate(K&& key, const Handle& value) { - return storage_.insert_or_assign(std::make_pair(std::move(key), value)) - .second; + return storage_.insert_or_assign(std::move(key), value).second; } absl::StatusOr InsertOrUpdate(K&& key, Handle&& value) { - return storage_ - .insert_or_assign(std::make_pair(std::move(key), std::move(value))) - .second; + return storage_.insert_or_assign(std::move(key), std::move(value)).second; } bool Has(const Handle& key) const override { return Has(key.As()); } From dec5fd588f22d8dd460b96436a98b55b6a64c3ee Mon Sep 17 00:00:00 2001 From: jdtatum Date: Wed, 10 May 2023 21:34:15 +0000 Subject: [PATCH 257/303] Add constant step implementation that is inspectable during program planning. PiperOrigin-RevId: 531007748 --- eval/eval/BUILD | 31 +++++++++ eval/eval/compiler_constant_step.cc | 24 +++++++ eval/eval/compiler_constant_step.h | 49 ++++++++++++++ eval/eval/compiler_constant_step_test.cc | 86 ++++++++++++++++++++++++ eval/eval/const_value_step.cc | 11 ++- eval/eval/const_value_step.h | 1 - eval/eval/evaluator_core.h | 8 ++- eval/eval/evaluator_core_test.cc | 8 +++ eval/eval/expression_step_base.h | 4 ++ 9 files changed, 213 insertions(+), 9 deletions(-) create mode 100644 eval/eval/compiler_constant_step.cc create mode 100644 eval/eval/compiler_constant_step.h create mode 100644 eval/eval/compiler_constant_step_test.cc diff --git a/eval/eval/BUILD b/eval/eval/BUILD index 790d5b183..43df77339 100644 --- a/eval/eval/BUILD +++ b/eval/eval/BUILD @@ -33,6 +33,7 @@ cc_library( "//eval/public:unknown_attribute_set", "//extensions/protobuf:memory_manager", "//internal:casts", + "//internal:rtti", "//internal:status_macros", "//runtime:activation_interface", "//runtime:runtime_options", @@ -95,6 +96,7 @@ cc_library( "const_value_step.h", ], deps = [ + ":compiler_constant_step", ":evaluator_core", ":expression_step_base", "//base:ast_internal", @@ -835,3 +837,32 @@ cc_library( "@com_google_protobuf//:protobuf", ], ) + +cc_library( + name = "compiler_constant_step", + srcs = ["compiler_constant_step.cc"], + hdrs = ["compiler_constant_step.h"], + deps = [ + ":expression_step_base", + "//internal:rtti", + ], +) + +cc_test( + name = "compiler_constant_step_test", + srcs = ["compiler_constant_step_test.cc"], + deps = [ + ":compiler_constant_step", + ":evaluator_core", + ":test_type_registry", + "//base:type", + "//base:value", + "//eval/public:activation", + "//eval/public:cel_expression", + "//extensions/protobuf:memory_manager", + "//internal:rtti", + "//internal:status_macros", + "//internal:testing", + "//runtime:runtime_options", + ], +) diff --git a/eval/eval/compiler_constant_step.cc b/eval/eval/compiler_constant_step.cc new file mode 100644 index 000000000..9933dd06b --- /dev/null +++ b/eval/eval/compiler_constant_step.cc @@ -0,0 +1,24 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "eval/eval/compiler_constant_step.h" + +namespace google::api::expr::runtime { + +absl::Status CompilerConstantStep::Evaluate(ExecutionFrame* frame) const { + frame->value_stack().Push(value_); + + return absl::OkStatus(); +} + +} // namespace google::api::expr::runtime diff --git a/eval/eval/compiler_constant_step.h b/eval/eval/compiler_constant_step.h new file mode 100644 index 000000000..26a7f5886 --- /dev/null +++ b/eval/eval/compiler_constant_step.h @@ -0,0 +1,49 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_COMPILER_CONSTANT_STEP_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_EVAL_COMPILER_CONSTANT_STEP_H_ + +#include + +#include "eval/eval/expression_step_base.h" +#include "internal/rtti.h" + +namespace google::api::expr::runtime { + +// ExpressionStep implementation that simply pushes a constant value on the +// stack. +// +// Overrides TypeInfo to allow the FlatExprBuilder and extensions to inspect +// the underlying value. +class CompilerConstantStep : public ExpressionStepBase { + public: + CompilerConstantStep(cel::Handle value, int64_t expr_id, + bool comes_from_ast) + : ExpressionStepBase(expr_id, comes_from_ast), value_(std::move(value)) {} + + absl::Status Evaluate(ExecutionFrame* frame) const override; + + cel::internal::TypeInfo TypeId() const override { + return cel::internal::TypeId(); + } + + const cel::Handle& value() const { return value_; } + + private: + cel::Handle value_; +}; + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_COMPILER_CONSTANT_STEP_H_ diff --git a/eval/eval/compiler_constant_step_test.cc b/eval/eval/compiler_constant_step_test.cc new file mode 100644 index 000000000..cc48f296e --- /dev/null +++ b/eval/eval/compiler_constant_step_test.cc @@ -0,0 +1,86 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "eval/eval/compiler_constant_step.h" + +#include + +#include "base/type_factory.h" +#include "base/type_manager.h" +#include "base/value_factory.h" +#include "base/values/int_value.h" +#include "eval/eval/evaluator_core.h" +#include "eval/eval/test_type_registry.h" +#include "eval/public/activation.h" +#include "eval/public/cel_expression.h" +#include "extensions/protobuf/memory_manager.h" +#include "internal/rtti.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "runtime/runtime_options.h" + +namespace google::api::expr::runtime { + +namespace { + +class CompilerConstantStepTest : public testing::Test { + public: + CompilerConstantStepTest() + : memory_manager_(&arena_), + type_factory_(memory_manager_), + type_manager_(type_factory_, cel::TypeProvider::Builtin()), + value_factory_(type_manager_), + state_(2, &arena_) {} + + protected: + google::protobuf::Arena arena_; + cel::extensions::ProtoMemoryManager memory_manager_; + cel::TypeFactory type_factory_; + cel::TypeManager type_manager_; + cel::ValueFactory value_factory_; + + CelExpressionFlatEvaluationState state_; + Activation empty_activation_; + cel::RuntimeOptions options_; +}; + +TEST_F(CompilerConstantStepTest, Evaluate) { + ExecutionPath path; + path.push_back(std::make_unique( + value_factory_.CreateIntValue(42), -1, false)); + + ExecutionFrame frame(path, empty_activation_, &TestTypeRegistry(), options_, + &state_); + + ASSERT_OK_AND_ASSIGN(cel::Handle result, + frame.Evaluate(CelEvaluationListener())); + + EXPECT_EQ(result->As().value(), 42); +} + +TEST_F(CompilerConstantStepTest, TypeId) { + CompilerConstantStep step(value_factory_.CreateIntValue(42), -1, false); + + ExpressionStep& abstract_step = step; + EXPECT_EQ(abstract_step.TypeId(), + cel::internal::TypeId()); +} + +TEST_F(CompilerConstantStepTest, Value) { + CompilerConstantStep step(value_factory_.CreateIntValue(42), -1, false); + + EXPECT_EQ(step.value()->As().value(), 42); +} + +} // namespace +} // namespace google::api::expr::runtime diff --git a/eval/eval/const_value_step.cc b/eval/eval/const_value_step.cc index 4d42baf8c..8c10a4c68 100644 --- a/eval/eval/const_value_step.cc +++ b/eval/eval/const_value_step.cc @@ -8,6 +8,7 @@ #include "absl/status/statusor.h" #include "absl/time/time.h" #include "base/ast_internal.h" +#include "eval/eval/compiler_constant_step.h" #include "eval/eval/expression_step_base.h" #include "eval/internal/interop.h" @@ -19,10 +20,6 @@ using ::cel::ast::internal::Constant; class ConstValueStep : public ExpressionStepBase { public: - ConstValueStep(cel::Handle value, int64_t expr_id, - bool comes_from_ast) - : ExpressionStepBase(expr_id, comes_from_ast), value_(std::move(value)) {} - ConstValueStep(const Constant& expr, int64_t expr_id, bool comes_from_ast) : ExpressionStepBase(expr_id, comes_from_ast), const_expr_(expr), @@ -31,7 +28,7 @@ class ConstValueStep : public ExpressionStepBase { absl::Status Evaluate(ExecutionFrame* frame) const override; private: - // Mainain a copy of the source constant to avoid lifecycle dependence on the + // Maintain a copy of the source constant to avoid lifecycle dependence on the // ast after planning. cel::ast::internal::Constant const_expr_; cel::Handle value_; @@ -82,8 +79,8 @@ cel::Handle ConvertConstant( absl::StatusOr> CreateConstValueStep( cel::Handle value, int64_t expr_id, bool comes_from_ast) { - return std::make_unique(std::move(value), expr_id, - comes_from_ast); + return std::make_unique(std::move(value), expr_id, + comes_from_ast); } absl::StatusOr> CreateConstValueStep( diff --git a/eval/eval/const_value_step.h b/eval/eval/const_value_step.h index a2e71d8b7..15ae6408f 100644 --- a/eval/eval/const_value_step.h +++ b/eval/eval/const_value_step.h @@ -9,7 +9,6 @@ #include "base/handle.h" #include "base/value.h" #include "eval/eval/evaluator_core.h" -#include "eval/public/cel_value.h" namespace google::api::expr::runtime { diff --git a/eval/eval/evaluator_core.h b/eval/eval/evaluator_core.h index 3045c0419..b10bb4efb 100644 --- a/eval/eval/evaluator_core.h +++ b/eval/eval/evaluator_core.h @@ -39,6 +39,7 @@ #include "eval/public/cel_value.h" #include "eval/public/unknown_attribute_set.h" #include "extensions/protobuf/memory_manager.h" +#include "internal/rtti.h" #include "runtime/activation_interface.h" #include "runtime/runtime_options.h" @@ -73,6 +74,11 @@ class ExpressionStep { // Returns if the execution step comes from AST. virtual bool ComesFromAst() const = 0; + + // Return the type of the underlying expression step for special handling in + // the planning phase. This should only be overridden by special cases, and + // callers must not make any assumptions about the default case. + virtual cel::internal::TypeInfo TypeId() const = 0; }; using ExecutionPath = std::vector>; @@ -315,7 +321,7 @@ class CelExpressionFlatImpl : public CelExpression { CelEvaluationListener callback) const override; private: - // Arena used while builting the expression, must live as long. + // Arena used while building the expression, must live as long. const std::unique_ptr arena_; const ExecutionPath path_; const CelTypeRegistry& type_registry_; diff --git a/eval/eval/evaluator_core_test.cc b/eval/eval/evaluator_core_test.cc index fa3017f53..52ec09f01 100644 --- a/eval/eval/evaluator_core_test.cc +++ b/eval/eval/evaluator_core_test.cc @@ -39,6 +39,10 @@ class FakeConstExpressionStep : public ExpressionStep { int64_t id() const override { return 0; } bool ComesFromAst() const override { return true; } + + cel::internal::TypeInfo TypeId() const override { + return cel::internal::TypeInfo(); + } }; // Fake expression implementation @@ -58,6 +62,10 @@ class FakeIncrementExpressionStep : public ExpressionStep { int64_t id() const override { return 0; } bool ComesFromAst() const override { return true; } + + cel::internal::TypeInfo TypeId() const override { + return cel::internal::TypeInfo(); + } }; TEST(EvaluatorCoreTest, ExecutionFrameNext) { diff --git a/eval/eval/expression_step_base.h b/eval/eval/expression_step_base.h index 58353aabf..b8341a1f1 100644 --- a/eval/eval/expression_step_base.h +++ b/eval/eval/expression_step_base.h @@ -22,6 +22,10 @@ class ExpressionStepBase : public ExpressionStep { // Returns if the execution step comes from AST. bool ComesFromAst() const override { return comes_from_ast_; } + cel::internal::TypeInfo TypeId() const override { + return cel::internal::TypeInfo(); + } + private: int64_t id_; bool comes_from_ast_; From 07c3378b1f8a64c3eafe9b7db591205853635441 Mon Sep 17 00:00:00 2001 From: jcking Date: Wed, 10 May 2023 21:45:15 +0000 Subject: [PATCH 258/303] Deduplicate `ProtoStructValue::HasField` tests PiperOrigin-RevId: 531010407 --- extensions/protobuf/struct_value_test.cc | 1183 ++++------------------ 1 file changed, 169 insertions(+), 1014 deletions(-) diff --git a/extensions/protobuf/struct_value_test.cc b/extensions/protobuf/struct_value_test.cc index 7590f0182..8a87b5f0f 100644 --- a/extensions/protobuf/struct_value_test.cc +++ b/extensions/protobuf/struct_value_test.cc @@ -90,54 +90,33 @@ T Must(absl::StatusOr status_or) { return Must(std::move(status_or).value()); } -TEST_P(ProtoStructValueTest, NullValueHasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT( - value_without->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("null_value")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.set_null_value(NULL_VALUE); - }))); - // In proto3, this can never be present as it will always be the default - // value. We would need to add `optional` for it to work. - EXPECT_THAT(value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("null_value")), - IsOkAndHolds(Eq(false))); -} - -TEST_P(ProtoStructValueTest, OptionalNullValueHasField) { - TypeFactory type_factory(memory_manager()); +// Implementation for ProtoStructValue::HasField. This should be the one and +// only call in the function body. +// +// NOTE: Explore using parameter generator approach instead. +template +void TestHasField(MemoryManager& memory_manager, ProtoStructType::FieldId id, + TestMessageMaker&& test_message_maker) { + TypeFactory type_factory(memory_manager); ProtoTypeProvider type_provider; TypeManager type_manager(type_factory, type_provider); ValueFactory value_factory(type_manager); ASSERT_OK_AND_ASSIGN(auto value_without, ProtoValue::Create(value_factory, CreateTestMessage())); EXPECT_THAT( - value_without->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("optional_null_value")), + value_without->HasField(StructValue::HasFieldContext(type_manager), id), IsOkAndHolds(Eq(false))); ASSERT_OK_AND_ASSIGN( auto value_with, ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.set_optional_null_value(NULL_VALUE); - }))); + CreateTestMessage(std::forward( + test_message_maker)))); EXPECT_THAT( - value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("optional_null_value")), + value_with->HasField(StructValue::HasFieldContext(type_manager), id), IsOkAndHolds(Eq(true))); } -TEST_P(ProtoStructValueTest, BoolHasField) { +TEST_P(ProtoStructValueTest, NullValueHasField) { TypeFactory type_factory(memory_manager()); ProtoTypeProvider type_provider; TypeManager type_manager(type_factory, type_provider); @@ -146,1143 +125,319 @@ TEST_P(ProtoStructValueTest, BoolHasField) { ProtoValue::Create(value_factory, CreateTestMessage())); EXPECT_THAT( value_without->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("single_bool")), + ProtoStructType::FieldId("null_value")), IsOkAndHolds(Eq(false))); ASSERT_OK_AND_ASSIGN( auto value_with, ProtoValue::Create(value_factory, CreateTestMessage([](TestAllTypes& message) { - message.set_single_bool(true); + message.set_null_value(NULL_VALUE); }))); + // In proto3, this can never be present as it will always be the default + // value. We would need to add `optional` for it to work. EXPECT_THAT(value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("single_bool")), - IsOkAndHolds(Eq(true))); + ProtoStructType::FieldId("null_value")), + IsOkAndHolds(Eq(false))); +} + +TEST_P(ProtoStructValueTest, OptionalNullValueHasField) { + TestHasField(memory_manager(), + ProtoStructType::FieldId("optional_null_value"), + [](TestAllTypes& message) { + message.set_optional_null_value(NULL_VALUE); + }); +} + +TEST_P(ProtoStructValueTest, BoolHasField) { + TestHasField(memory_manager(), ProtoStructType::FieldId("single_bool"), + [](TestAllTypes& message) { message.set_single_bool(true); }); } TEST_P(ProtoStructValueTest, Int32HasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT( - value_without->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("single_int32")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.set_single_int32(1); - }))); - EXPECT_THAT(value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("single_int32")), - IsOkAndHolds(Eq(true))); + TestHasField(memory_manager(), ProtoStructType::FieldId("single_int32"), + [](TestAllTypes& message) { message.set_single_int32(1); }); } TEST_P(ProtoStructValueTest, Int64HasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT( - value_without->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("single_int64")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.set_single_int64(1); - }))); - EXPECT_THAT(value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("single_int64")), - IsOkAndHolds(Eq(true))); + TestHasField(memory_manager(), ProtoStructType::FieldId("single_int64"), + [](TestAllTypes& message) { message.set_single_int64(1); }); } TEST_P(ProtoStructValueTest, Uint32HasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT( - value_without->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("single_uint32")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.set_single_uint32(1); - }))); - EXPECT_THAT(value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("single_uint32")), - IsOkAndHolds(Eq(true))); + TestHasField(memory_manager(), ProtoStructType::FieldId("single_uint32"), + [](TestAllTypes& message) { message.set_single_uint32(1); }); } TEST_P(ProtoStructValueTest, Uint64HasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT( - value_without->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("single_uint64")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.set_single_uint64(1); - }))); - EXPECT_THAT(value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("single_uint64")), - IsOkAndHolds(Eq(true))); + TestHasField(memory_manager(), ProtoStructType::FieldId("single_uint64"), + [](TestAllTypes& message) { message.set_single_uint64(1); }); } TEST_P(ProtoStructValueTest, FloatHasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT( - value_without->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("single_float")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.set_single_float(1.0); - }))); - EXPECT_THAT(value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("single_float")), - IsOkAndHolds(Eq(true))); + TestHasField(memory_manager(), ProtoStructType::FieldId("single_float"), + [](TestAllTypes& message) { message.set_single_float(1.0); }); } TEST_P(ProtoStructValueTest, DoubleHasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT( - value_without->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("single_double")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.set_single_double(1.0); - }))); - EXPECT_THAT(value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("single_double")), - IsOkAndHolds(Eq(true))); + TestHasField(memory_manager(), ProtoStructType::FieldId("single_double"), + [](TestAllTypes& message) { message.set_single_double(1.0); }); } TEST_P(ProtoStructValueTest, BytesHasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT( - value_without->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("single_bytes")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.set_single_bytes("foo"); - }))); - EXPECT_THAT(value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("single_bytes")), - IsOkAndHolds(Eq(true))); + TestHasField(memory_manager(), ProtoStructType::FieldId("single_bytes"), + [](TestAllTypes& message) { message.set_single_bytes("foo"); }); } TEST_P(ProtoStructValueTest, StringHasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT( - value_without->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("single_string")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.set_single_string("foo"); - }))); - EXPECT_THAT(value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("single_string")), - IsOkAndHolds(Eq(true))); + TestHasField(memory_manager(), ProtoStructType::FieldId("single_string"), + [](TestAllTypes& message) { message.set_single_string("foo"); }); } TEST_P(ProtoStructValueTest, DurationHasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT( - value_without->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("single_duration")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.mutable_single_duration(); - }))); - EXPECT_THAT(value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("single_duration")), - IsOkAndHolds(Eq(true))); + TestHasField( + memory_manager(), ProtoStructType::FieldId("single_duration"), + [](TestAllTypes& message) { message.mutable_single_duration(); }); } TEST_P(ProtoStructValueTest, TimestampHasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT( - value_without->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("single_timestamp")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.mutable_single_timestamp(); - }))); - EXPECT_THAT( - value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("single_timestamp")), - IsOkAndHolds(Eq(true))); + TestHasField( + memory_manager(), ProtoStructType::FieldId("single_timestamp"), + [](TestAllTypes& message) { message.mutable_single_timestamp(); }); } TEST_P(ProtoStructValueTest, EnumHasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT( - value_without->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("standalone_enum")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.set_standalone_enum(TestAllTypes::BAR); - }))); - EXPECT_THAT(value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("standalone_enum")), - IsOkAndHolds(Eq(true))); + TestHasField(memory_manager(), ProtoStructType::FieldId("standalone_enum"), + [](TestAllTypes& message) { + message.set_standalone_enum(TestAllTypes::BAR); + }); } TEST_P(ProtoStructValueTest, MessageHasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT( - value_without->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("standalone_message")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.mutable_standalone_message(); - }))); - EXPECT_THAT( - value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("standalone_message")), - IsOkAndHolds(Eq(true))); + TestHasField( + memory_manager(), ProtoStructType::FieldId("standalone_message"), + [](TestAllTypes& message) { message.mutable_standalone_message(); }); } TEST_P(ProtoStructValueTest, BoolWrapperHasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT( - value_without->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("single_bool_wrapper")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.mutable_single_bool_wrapper(); - }))); - EXPECT_THAT( - value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("single_bool_wrapper")), - IsOkAndHolds(Eq(true))); + TestHasField( + memory_manager(), ProtoStructType::FieldId("single_bool_wrapper"), + [](TestAllTypes& message) { message.mutable_single_bool_wrapper(); }); } TEST_P(ProtoStructValueTest, Int32WrapperHasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT( - value_without->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("single_int32_wrapper")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.mutable_single_int32_wrapper(); - }))); - EXPECT_THAT( - value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("single_int32_wrapper")), - IsOkAndHolds(Eq(true))); + TestHasField( + memory_manager(), ProtoStructType::FieldId("single_int32_wrapper"), + [](TestAllTypes& message) { message.mutable_single_int32_wrapper(); }); } TEST_P(ProtoStructValueTest, Int64WrapperHasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT( - value_without->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("single_int64_wrapper")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.mutable_single_int64_wrapper(); - }))); - EXPECT_THAT( - value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("single_int64_wrapper")), - IsOkAndHolds(Eq(true))); + TestHasField( + memory_manager(), ProtoStructType::FieldId("single_int64_wrapper"), + [](TestAllTypes& message) { message.mutable_single_int64_wrapper(); }); } TEST_P(ProtoStructValueTest, UInt32WrapperHasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT(value_without->HasField( - StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("single_uint32_wrapper")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.mutable_single_uint32_wrapper(); - }))); - EXPECT_THAT( - value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("single_uint32_wrapper")), - IsOkAndHolds(Eq(true))); + TestHasField( + memory_manager(), ProtoStructType::FieldId("single_uint32_wrapper"), + [](TestAllTypes& message) { message.mutable_single_uint32_wrapper(); }); } TEST_P(ProtoStructValueTest, UInt64WrapperHasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT(value_without->HasField( - StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("single_uint64_wrapper")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.mutable_single_uint64_wrapper(); - }))); - EXPECT_THAT( - value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("single_uint64_wrapper")), - IsOkAndHolds(Eq(true))); + TestHasField( + memory_manager(), ProtoStructType::FieldId("single_uint64_wrapper"), + [](TestAllTypes& message) { message.mutable_single_uint64_wrapper(); }); } TEST_P(ProtoStructValueTest, FloatWrapperHasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT( - value_without->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("single_float_wrapper")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.mutable_single_float_wrapper(); - }))); - EXPECT_THAT( - value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("single_float_wrapper")), - IsOkAndHolds(Eq(true))); + TestHasField( + memory_manager(), ProtoStructType::FieldId("single_float_wrapper"), + [](TestAllTypes& message) { message.mutable_single_float_wrapper(); }); } TEST_P(ProtoStructValueTest, DoubleWrapperHasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT(value_without->HasField( - StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("single_double_wrapper")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.mutable_single_double_wrapper(); - }))); - EXPECT_THAT( - value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("single_double_wrapper")), - IsOkAndHolds(Eq(true))); + TestHasField( + memory_manager(), ProtoStructType::FieldId("single_double_wrapper"), + [](TestAllTypes& message) { message.mutable_single_double_wrapper(); }); } TEST_P(ProtoStructValueTest, BytesWrapperHasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT( - value_without->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("single_bytes_wrapper")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.mutable_single_bytes_wrapper(); - }))); - EXPECT_THAT( - value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("single_bytes_wrapper")), - IsOkAndHolds(Eq(true))); + TestHasField( + memory_manager(), ProtoStructType::FieldId("single_bytes_wrapper"), + [](TestAllTypes& message) { message.mutable_single_bytes_wrapper(); }); } TEST_P(ProtoStructValueTest, StringWrapperHasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT(value_without->HasField( - StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("single_string_wrapper")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.mutable_single_string_wrapper(); - }))); - EXPECT_THAT( - value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("single_string_wrapper")), - IsOkAndHolds(Eq(true))); + TestHasField( + memory_manager(), ProtoStructType::FieldId("single_string_wrapper"), + [](TestAllTypes& message) { message.mutable_single_string_wrapper(); }); } TEST_P(ProtoStructValueTest, ListValueHasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT( - value_without->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("list_value")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.mutable_list_value(); - }))); - EXPECT_THAT(value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("list_value")), - IsOkAndHolds(Eq(true))); + TestHasField(memory_manager(), ProtoStructType::FieldId("list_value"), + [](TestAllTypes& message) { message.mutable_list_value(); }); } TEST_P(ProtoStructValueTest, StructHasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT( - value_without->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("single_struct")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.mutable_single_struct(); - }))); - EXPECT_THAT(value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("single_struct")), - IsOkAndHolds(Eq(true))); + TestHasField(memory_manager(), ProtoStructType::FieldId("single_struct"), + [](TestAllTypes& message) { message.mutable_single_struct(); }); } TEST_P(ProtoStructValueTest, ValueHasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT( - value_without->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("single_value")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.mutable_single_value(); - }))); - EXPECT_THAT(value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("single_value")), - IsOkAndHolds(Eq(true))); + TestHasField(memory_manager(), ProtoStructType::FieldId("single_value"), + [](TestAllTypes& message) { message.mutable_single_value(); }); } TEST_P(ProtoStructValueTest, NullValueListHasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT( - value_without->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_null_value")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.add_repeated_null_value(NULL_VALUE); - }))); - EXPECT_THAT( - value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_null_value")), - IsOkAndHolds(Eq(true))); + TestHasField(memory_manager(), + ProtoStructType::FieldId("repeated_null_value"), + [](TestAllTypes& message) { + message.add_repeated_null_value(NULL_VALUE); + }); } TEST_P(ProtoStructValueTest, BoolListHasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT( - value_without->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_bool")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.add_repeated_bool(true); - }))); - EXPECT_THAT(value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_bool")), - IsOkAndHolds(Eq(true))); + TestHasField(memory_manager(), ProtoStructType::FieldId("repeated_bool"), + [](TestAllTypes& message) { message.add_repeated_bool(true); }); } TEST_P(ProtoStructValueTest, Int32ListHasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT( - value_without->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_int32")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.add_repeated_int32(1); - }))); - EXPECT_THAT(value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_int32")), - IsOkAndHolds(Eq(true))); + TestHasField(memory_manager(), ProtoStructType::FieldId("repeated_int32"), + [](TestAllTypes& message) { message.add_repeated_int32(true); }); } TEST_P(ProtoStructValueTest, Int64ListHasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT( - value_without->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_int64")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.add_repeated_int64(1); - }))); - EXPECT_THAT(value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_int64")), - IsOkAndHolds(Eq(true))); + TestHasField(memory_manager(), ProtoStructType::FieldId("repeated_int64"), + [](TestAllTypes& message) { message.add_repeated_int64(1); }); } TEST_P(ProtoStructValueTest, Uint32ListHasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT( - value_without->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_uint32")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.add_repeated_uint32(1); - }))); - EXPECT_THAT(value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_uint32")), - IsOkAndHolds(Eq(true))); + TestHasField(memory_manager(), ProtoStructType::FieldId("repeated_uint32"), + [](TestAllTypes& message) { message.add_repeated_uint32(1); }); } TEST_P(ProtoStructValueTest, Uint64ListHasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT( - value_without->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_uint64")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.add_repeated_uint64(1); - }))); - EXPECT_THAT(value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_uint64")), - IsOkAndHolds(Eq(true))); + TestHasField(memory_manager(), ProtoStructType::FieldId("repeated_uint64"), + [](TestAllTypes& message) { message.add_repeated_uint64(1); }); } TEST_P(ProtoStructValueTest, FloatListHasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT( - value_without->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_float")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.add_repeated_float(1.0); - }))); - EXPECT_THAT(value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_float")), - IsOkAndHolds(Eq(true))); + TestHasField(memory_manager(), ProtoStructType::FieldId("repeated_float"), + [](TestAllTypes& message) { message.add_repeated_float(1.0); }); } TEST_P(ProtoStructValueTest, DoubleListHasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT( - value_without->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_double")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.add_repeated_double(1.0); - }))); - EXPECT_THAT(value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_double")), - IsOkAndHolds(Eq(true))); + TestHasField(memory_manager(), ProtoStructType::FieldId("repeated_double"), + [](TestAllTypes& message) { message.add_repeated_double(1.0); }); } TEST_P(ProtoStructValueTest, BytesListHasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT( - value_without->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_bytes")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.add_repeated_bytes("foo"); - }))); - EXPECT_THAT(value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_bytes")), - IsOkAndHolds(Eq(true))); + TestHasField( + memory_manager(), ProtoStructType::FieldId("repeated_bytes"), + [](TestAllTypes& message) { message.add_repeated_bytes("foo"); }); } TEST_P(ProtoStructValueTest, StringListHasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT( - value_without->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_string")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.add_repeated_string("foo"); - }))); - EXPECT_THAT(value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_string")), - IsOkAndHolds(Eq(true))); + TestHasField( + memory_manager(), ProtoStructType::FieldId("repeated_string"), + [](TestAllTypes& message) { message.add_repeated_string("foo"); }); } TEST_P(ProtoStructValueTest, DurationListHasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT( - value_without->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_duration")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.add_repeated_duration()->set_seconds(1); - }))); - EXPECT_THAT( - value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_duration")), - IsOkAndHolds(Eq(true))); + TestHasField(memory_manager(), ProtoStructType::FieldId("repeated_duration"), + [](TestAllTypes& message) { + message.add_repeated_duration()->set_seconds(1); + }); } TEST_P(ProtoStructValueTest, TimestampListHasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT( - value_without->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_timestamp")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.add_repeated_timestamp()->set_seconds(1); - }))); - EXPECT_THAT( - value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_timestamp")), - IsOkAndHolds(Eq(true))); + TestHasField(memory_manager(), ProtoStructType::FieldId("repeated_timestamp"), + [](TestAllTypes& message) { + message.add_repeated_timestamp()->set_seconds(1); + }); } TEST_P(ProtoStructValueTest, EnumListHasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT( - value_without->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_nested_enum")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.add_repeated_nested_enum(TestAllTypes::BAR); - }))); - EXPECT_THAT( - value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_nested_enum")), - IsOkAndHolds(Eq(true))); + TestHasField(memory_manager(), + ProtoStructType::FieldId("repeated_nested_enum"), + [](TestAllTypes& message) { + message.add_repeated_nested_enum(TestAllTypes::BAR); + }); } TEST_P(ProtoStructValueTest, MessageListHasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT(value_without->HasField( - StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_nested_message")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.add_repeated_nested_message(); - }))); - EXPECT_THAT( - value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_nested_message")), - IsOkAndHolds(Eq(true))); + TestHasField( + memory_manager(), ProtoStructType::FieldId("repeated_nested_message"), + [](TestAllTypes& message) { message.add_repeated_nested_message(); }); } TEST_P(ProtoStructValueTest, BoolWrapperListHasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT(value_without->HasField( - StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_bool_wrapper")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.add_repeated_bool_wrapper(); - }))); - EXPECT_THAT( - value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_bool_wrapper")), - IsOkAndHolds(Eq(true))); + TestHasField( + memory_manager(), ProtoStructType::FieldId("repeated_bool_wrapper"), + [](TestAllTypes& message) { message.add_repeated_bool_wrapper(); }); } TEST_P(ProtoStructValueTest, Int32WrapperListHasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT(value_without->HasField( - StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_int32_wrapper")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.add_repeated_int32_wrapper(); - }))); - EXPECT_THAT( - value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_int32_wrapper")), - IsOkAndHolds(Eq(true))); + TestHasField( + memory_manager(), ProtoStructType::FieldId("repeated_int32_wrapper"), + [](TestAllTypes& message) { message.add_repeated_int32_wrapper(); }); } TEST_P(ProtoStructValueTest, Int64WrapperListHasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT(value_without->HasField( - StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_int64_wrapper")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.add_repeated_int64_wrapper(); - }))); - EXPECT_THAT( - value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_int64_wrapper")), - IsOkAndHolds(Eq(true))); + TestHasField( + memory_manager(), ProtoStructType::FieldId("repeated_int64_wrapper"), + [](TestAllTypes& message) { message.add_repeated_int64_wrapper(); }); } TEST_P(ProtoStructValueTest, Uint32WrapperListHasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT(value_without->HasField( - StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_uint32_wrapper")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.add_repeated_uint32_wrapper(); - }))); - EXPECT_THAT( - value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_uint32_wrapper")), - IsOkAndHolds(Eq(true))); + TestHasField( + memory_manager(), ProtoStructType::FieldId("repeated_uint32_wrapper"), + [](TestAllTypes& message) { message.add_repeated_uint32_wrapper(); }); } TEST_P(ProtoStructValueTest, Uint64WrapperListHasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT(value_without->HasField( - StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_uint64_wrapper")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.add_repeated_uint64_wrapper(); - }))); - EXPECT_THAT( - value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_uint64_wrapper")), - IsOkAndHolds(Eq(true))); + TestHasField( + memory_manager(), ProtoStructType::FieldId("repeated_uint64_wrapper"), + [](TestAllTypes& message) { message.add_repeated_uint64_wrapper(); }); } TEST_P(ProtoStructValueTest, FloatWrapperListHasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT(value_without->HasField( - StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_float_wrapper")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.add_repeated_float_wrapper(); - }))); - EXPECT_THAT( - value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_float_wrapper")), - IsOkAndHolds(Eq(true))); + TestHasField( + memory_manager(), ProtoStructType::FieldId("repeated_float_wrapper"), + [](TestAllTypes& message) { message.add_repeated_float_wrapper(); }); } TEST_P(ProtoStructValueTest, DoubleWrapperListHasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT(value_without->HasField( - StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_double_wrapper")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.add_repeated_double_wrapper(); - }))); - EXPECT_THAT( - value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_double_wrapper")), - IsOkAndHolds(Eq(true))); + TestHasField( + memory_manager(), ProtoStructType::FieldId("repeated_double_wrapper"), + [](TestAllTypes& message) { message.add_repeated_double_wrapper(); }); } TEST_P(ProtoStructValueTest, BytesWrapperListHasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT(value_without->HasField( - StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_bytes_wrapper")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.add_repeated_bytes_wrapper(); - }))); - EXPECT_THAT( - value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_bytes_wrapper")), - IsOkAndHolds(Eq(true))); + TestHasField( + memory_manager(), ProtoStructType::FieldId("repeated_bytes_wrapper"), + [](TestAllTypes& message) { message.add_repeated_bytes_wrapper(); }); } TEST_P(ProtoStructValueTest, StringWrapperListHasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT(value_without->HasField( - StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_string_wrapper")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.add_repeated_string_wrapper(); - }))); - EXPECT_THAT( - value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_string_wrapper")), - IsOkAndHolds(Eq(true))); + TestHasField( + memory_manager(), ProtoStructType::FieldId("repeated_string_wrapper"), + [](TestAllTypes& message) { message.add_repeated_string_wrapper(); }); } TEST_P(ProtoStructValueTest, ListValueListHasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT( - value_without->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_list_value")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.add_repeated_list_value(); - }))); - EXPECT_THAT( - value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_list_value")), - IsOkAndHolds(Eq(true))); + TestHasField( + memory_manager(), ProtoStructType::FieldId("repeated_list_value"), + [](TestAllTypes& message) { message.add_repeated_list_value(); }); } TEST_P(ProtoStructValueTest, StructListHasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT( - value_without->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_struct")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.add_repeated_struct(); - }))); - EXPECT_THAT(value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_struct")), - IsOkAndHolds(Eq(true))); + TestHasField(memory_manager(), ProtoStructType::FieldId("repeated_struct"), + [](TestAllTypes& message) { message.add_repeated_struct(); }); } TEST_P(ProtoStructValueTest, ValueListHasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT( - value_without->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_value")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.add_repeated_value(); - }))); - EXPECT_THAT(value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("repeated_value")), - IsOkAndHolds(Eq(true))); + TestHasField(memory_manager(), ProtoStructType::FieldId("repeated_value"), + [](TestAllTypes& message) { message.add_repeated_value(); }); } TEST_P(ProtoStructValueTest, NullValueGetField) { From 91c808f2117e85620430c1dfa7c8adf4085a3dfc Mon Sep 17 00:00:00 2001 From: jcking Date: Thu, 11 May 2023 13:23:38 +0000 Subject: [PATCH 259/303] Yet more test simplifications PiperOrigin-RevId: 531182910 --- extensions/protobuf/BUILD | 1 + extensions/protobuf/struct_value_test.cc | 1270 ++++++++-------------- 2 files changed, 482 insertions(+), 789 deletions(-) diff --git a/extensions/protobuf/BUILD b/extensions/protobuf/BUILD index d4e12702d..c7b24fa94 100644 --- a/extensions/protobuf/BUILD +++ b/extensions/protobuf/BUILD @@ -194,6 +194,7 @@ cc_test( "//base/testing:value_matchers", "//extensions/protobuf/internal:testing", "//internal:testing", + "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/status", "@com_google_absl//absl/time", "@com_google_absl//absl/types:optional", diff --git a/extensions/protobuf/struct_value_test.cc b/extensions/protobuf/struct_value_test.cc index 8a87b5f0f..951873fe3 100644 --- a/extensions/protobuf/struct_value_test.cc +++ b/extensions/protobuf/struct_value_test.cc @@ -21,6 +21,7 @@ #include "google/protobuf/duration.pb.h" #include "google/protobuf/timestamp.pb.h" #include "google/protobuf/wrappers.pb.h" +#include "absl/functional/function_ref.h" #include "absl/status/status.h" #include "absl/time/time.h" #include "absl/types/optional.h" @@ -43,6 +44,7 @@ namespace cel::extensions { namespace { +using FieldId = ::cel::extensions::ProtoStructType::FieldId; using ::cel_testing::ValueOf; using testing::Eq; using testing::EqualsProto; @@ -90,13 +92,9 @@ T Must(absl::StatusOr status_or) { return Must(std::move(status_or).value()); } -// Implementation for ProtoStructValue::HasField. This should be the one and -// only call in the function body. -// -// NOTE: Explore using parameter generator approach instead. template void TestHasField(MemoryManager& memory_manager, ProtoStructType::FieldId id, - TestMessageMaker&& test_message_maker) { + TestMessageMaker&& test_message_maker, bool found = true) { TypeFactory type_factory(memory_manager); ProtoTypeProvider type_provider; TypeManager type_manager(type_factory, type_provider); @@ -113,335 +111,330 @@ void TestHasField(MemoryManager& memory_manager, ProtoStructType::FieldId id, test_message_maker)))); EXPECT_THAT( value_with->HasField(StructValue::HasFieldContext(type_manager), id), - IsOkAndHolds(Eq(true))); + IsOkAndHolds(Eq(found))); } +#define TEST_HAS_FIELD(...) ASSERT_NO_FATAL_FAILURE(TestHasField(__VA_ARGS__)) + TEST_P(ProtoStructValueTest, NullValueHasField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT( - value_without->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("null_value")), - IsOkAndHolds(Eq(false))); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.set_null_value(NULL_VALUE); - }))); // In proto3, this can never be present as it will always be the default // value. We would need to add `optional` for it to work. - EXPECT_THAT(value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId("null_value")), - IsOkAndHolds(Eq(false))); + TEST_HAS_FIELD( + memory_manager(), FieldId("null_value"), + [](TestAllTypes& message) { message.set_null_value(NULL_VALUE); }, false); } TEST_P(ProtoStructValueTest, OptionalNullValueHasField) { - TestHasField(memory_manager(), - ProtoStructType::FieldId("optional_null_value"), - [](TestAllTypes& message) { - message.set_optional_null_value(NULL_VALUE); - }); + TEST_HAS_FIELD(memory_manager(), FieldId("optional_null_value"), + [](TestAllTypes& message) { + message.set_optional_null_value(NULL_VALUE); + }); } TEST_P(ProtoStructValueTest, BoolHasField) { - TestHasField(memory_manager(), ProtoStructType::FieldId("single_bool"), - [](TestAllTypes& message) { message.set_single_bool(true); }); + TEST_HAS_FIELD(memory_manager(), FieldId("single_bool"), + [](TestAllTypes& message) { message.set_single_bool(true); }); } TEST_P(ProtoStructValueTest, Int32HasField) { - TestHasField(memory_manager(), ProtoStructType::FieldId("single_int32"), - [](TestAllTypes& message) { message.set_single_int32(1); }); + TEST_HAS_FIELD(memory_manager(), FieldId("single_int32"), + [](TestAllTypes& message) { message.set_single_int32(1); }); } TEST_P(ProtoStructValueTest, Int64HasField) { - TestHasField(memory_manager(), ProtoStructType::FieldId("single_int64"), - [](TestAllTypes& message) { message.set_single_int64(1); }); + TEST_HAS_FIELD(memory_manager(), FieldId("single_int64"), + [](TestAllTypes& message) { message.set_single_int64(1); }); } TEST_P(ProtoStructValueTest, Uint32HasField) { - TestHasField(memory_manager(), ProtoStructType::FieldId("single_uint32"), - [](TestAllTypes& message) { message.set_single_uint32(1); }); + TEST_HAS_FIELD(memory_manager(), FieldId("single_uint32"), + [](TestAllTypes& message) { message.set_single_uint32(1); }); } TEST_P(ProtoStructValueTest, Uint64HasField) { - TestHasField(memory_manager(), ProtoStructType::FieldId("single_uint64"), - [](TestAllTypes& message) { message.set_single_uint64(1); }); + TEST_HAS_FIELD(memory_manager(), FieldId("single_uint64"), + [](TestAllTypes& message) { message.set_single_uint64(1); }); } TEST_P(ProtoStructValueTest, FloatHasField) { - TestHasField(memory_manager(), ProtoStructType::FieldId("single_float"), - [](TestAllTypes& message) { message.set_single_float(1.0); }); + TEST_HAS_FIELD(memory_manager(), FieldId("single_float"), + [](TestAllTypes& message) { message.set_single_float(1.0); }); } TEST_P(ProtoStructValueTest, DoubleHasField) { - TestHasField(memory_manager(), ProtoStructType::FieldId("single_double"), - [](TestAllTypes& message) { message.set_single_double(1.0); }); + TEST_HAS_FIELD(memory_manager(), FieldId("single_double"), + [](TestAllTypes& message) { message.set_single_double(1.0); }); } TEST_P(ProtoStructValueTest, BytesHasField) { - TestHasField(memory_manager(), ProtoStructType::FieldId("single_bytes"), - [](TestAllTypes& message) { message.set_single_bytes("foo"); }); + TEST_HAS_FIELD( + memory_manager(), FieldId("single_bytes"), + [](TestAllTypes& message) { message.set_single_bytes("foo"); }); } TEST_P(ProtoStructValueTest, StringHasField) { - TestHasField(memory_manager(), ProtoStructType::FieldId("single_string"), - [](TestAllTypes& message) { message.set_single_string("foo"); }); + TEST_HAS_FIELD( + memory_manager(), FieldId("single_string"), + [](TestAllTypes& message) { message.set_single_string("foo"); }); } TEST_P(ProtoStructValueTest, DurationHasField) { - TestHasField( - memory_manager(), ProtoStructType::FieldId("single_duration"), + TEST_HAS_FIELD( + memory_manager(), FieldId("single_duration"), [](TestAllTypes& message) { message.mutable_single_duration(); }); } TEST_P(ProtoStructValueTest, TimestampHasField) { - TestHasField( - memory_manager(), ProtoStructType::FieldId("single_timestamp"), + TEST_HAS_FIELD( + memory_manager(), FieldId("single_timestamp"), [](TestAllTypes& message) { message.mutable_single_timestamp(); }); } TEST_P(ProtoStructValueTest, EnumHasField) { - TestHasField(memory_manager(), ProtoStructType::FieldId("standalone_enum"), - [](TestAllTypes& message) { - message.set_standalone_enum(TestAllTypes::BAR); - }); + TEST_HAS_FIELD(memory_manager(), FieldId("standalone_enum"), + [](TestAllTypes& message) { + message.set_standalone_enum(TestAllTypes::BAR); + }); } TEST_P(ProtoStructValueTest, MessageHasField) { - TestHasField( - memory_manager(), ProtoStructType::FieldId("standalone_message"), + TEST_HAS_FIELD( + memory_manager(), FieldId("standalone_message"), [](TestAllTypes& message) { message.mutable_standalone_message(); }); } TEST_P(ProtoStructValueTest, BoolWrapperHasField) { - TestHasField( - memory_manager(), ProtoStructType::FieldId("single_bool_wrapper"), + TEST_HAS_FIELD( + memory_manager(), FieldId("single_bool_wrapper"), [](TestAllTypes& message) { message.mutable_single_bool_wrapper(); }); } TEST_P(ProtoStructValueTest, Int32WrapperHasField) { - TestHasField( - memory_manager(), ProtoStructType::FieldId("single_int32_wrapper"), + TEST_HAS_FIELD( + memory_manager(), FieldId("single_int32_wrapper"), [](TestAllTypes& message) { message.mutable_single_int32_wrapper(); }); } TEST_P(ProtoStructValueTest, Int64WrapperHasField) { - TestHasField( - memory_manager(), ProtoStructType::FieldId("single_int64_wrapper"), + TEST_HAS_FIELD( + memory_manager(), FieldId("single_int64_wrapper"), [](TestAllTypes& message) { message.mutable_single_int64_wrapper(); }); } TEST_P(ProtoStructValueTest, UInt32WrapperHasField) { - TestHasField( - memory_manager(), ProtoStructType::FieldId("single_uint32_wrapper"), + TEST_HAS_FIELD( + memory_manager(), FieldId("single_uint32_wrapper"), [](TestAllTypes& message) { message.mutable_single_uint32_wrapper(); }); } TEST_P(ProtoStructValueTest, UInt64WrapperHasField) { - TestHasField( - memory_manager(), ProtoStructType::FieldId("single_uint64_wrapper"), + TEST_HAS_FIELD( + memory_manager(), FieldId("single_uint64_wrapper"), [](TestAllTypes& message) { message.mutable_single_uint64_wrapper(); }); } TEST_P(ProtoStructValueTest, FloatWrapperHasField) { - TestHasField( - memory_manager(), ProtoStructType::FieldId("single_float_wrapper"), + TEST_HAS_FIELD( + memory_manager(), FieldId("single_float_wrapper"), [](TestAllTypes& message) { message.mutable_single_float_wrapper(); }); } TEST_P(ProtoStructValueTest, DoubleWrapperHasField) { - TestHasField( - memory_manager(), ProtoStructType::FieldId("single_double_wrapper"), + TEST_HAS_FIELD( + memory_manager(), FieldId("single_double_wrapper"), [](TestAllTypes& message) { message.mutable_single_double_wrapper(); }); } TEST_P(ProtoStructValueTest, BytesWrapperHasField) { - TestHasField( - memory_manager(), ProtoStructType::FieldId("single_bytes_wrapper"), + TEST_HAS_FIELD( + memory_manager(), FieldId("single_bytes_wrapper"), [](TestAllTypes& message) { message.mutable_single_bytes_wrapper(); }); } TEST_P(ProtoStructValueTest, StringWrapperHasField) { - TestHasField( - memory_manager(), ProtoStructType::FieldId("single_string_wrapper"), + TEST_HAS_FIELD( + memory_manager(), FieldId("single_string_wrapper"), [](TestAllTypes& message) { message.mutable_single_string_wrapper(); }); } TEST_P(ProtoStructValueTest, ListValueHasField) { - TestHasField(memory_manager(), ProtoStructType::FieldId("list_value"), - [](TestAllTypes& message) { message.mutable_list_value(); }); + TEST_HAS_FIELD(memory_manager(), FieldId("list_value"), + [](TestAllTypes& message) { message.mutable_list_value(); }); } TEST_P(ProtoStructValueTest, StructHasField) { - TestHasField(memory_manager(), ProtoStructType::FieldId("single_struct"), - [](TestAllTypes& message) { message.mutable_single_struct(); }); + TEST_HAS_FIELD( + memory_manager(), FieldId("single_struct"), + [](TestAllTypes& message) { message.mutable_single_struct(); }); } TEST_P(ProtoStructValueTest, ValueHasField) { - TestHasField(memory_manager(), ProtoStructType::FieldId("single_value"), - [](TestAllTypes& message) { message.mutable_single_value(); }); + TEST_HAS_FIELD(memory_manager(), FieldId("single_value"), + [](TestAllTypes& message) { message.mutable_single_value(); }); } TEST_P(ProtoStructValueTest, NullValueListHasField) { - TestHasField(memory_manager(), - ProtoStructType::FieldId("repeated_null_value"), - [](TestAllTypes& message) { - message.add_repeated_null_value(NULL_VALUE); - }); + TEST_HAS_FIELD(memory_manager(), FieldId("repeated_null_value"), + [](TestAllTypes& message) { + message.add_repeated_null_value(NULL_VALUE); + }); } TEST_P(ProtoStructValueTest, BoolListHasField) { - TestHasField(memory_manager(), ProtoStructType::FieldId("repeated_bool"), - [](TestAllTypes& message) { message.add_repeated_bool(true); }); + TEST_HAS_FIELD( + memory_manager(), FieldId("repeated_bool"), + [](TestAllTypes& message) { message.add_repeated_bool(true); }); } TEST_P(ProtoStructValueTest, Int32ListHasField) { - TestHasField(memory_manager(), ProtoStructType::FieldId("repeated_int32"), - [](TestAllTypes& message) { message.add_repeated_int32(true); }); + TEST_HAS_FIELD( + memory_manager(), FieldId("repeated_int32"), + [](TestAllTypes& message) { message.add_repeated_int32(true); }); } TEST_P(ProtoStructValueTest, Int64ListHasField) { - TestHasField(memory_manager(), ProtoStructType::FieldId("repeated_int64"), - [](TestAllTypes& message) { message.add_repeated_int64(1); }); + TEST_HAS_FIELD(memory_manager(), FieldId("repeated_int64"), + [](TestAllTypes& message) { message.add_repeated_int64(1); }); } TEST_P(ProtoStructValueTest, Uint32ListHasField) { - TestHasField(memory_manager(), ProtoStructType::FieldId("repeated_uint32"), - [](TestAllTypes& message) { message.add_repeated_uint32(1); }); + TEST_HAS_FIELD(memory_manager(), FieldId("repeated_uint32"), + [](TestAllTypes& message) { message.add_repeated_uint32(1); }); } TEST_P(ProtoStructValueTest, Uint64ListHasField) { - TestHasField(memory_manager(), ProtoStructType::FieldId("repeated_uint64"), - [](TestAllTypes& message) { message.add_repeated_uint64(1); }); + TEST_HAS_FIELD(memory_manager(), FieldId("repeated_uint64"), + [](TestAllTypes& message) { message.add_repeated_uint64(1); }); } TEST_P(ProtoStructValueTest, FloatListHasField) { - TestHasField(memory_manager(), ProtoStructType::FieldId("repeated_float"), - [](TestAllTypes& message) { message.add_repeated_float(1.0); }); + TEST_HAS_FIELD( + memory_manager(), FieldId("repeated_float"), + [](TestAllTypes& message) { message.add_repeated_float(1.0); }); } TEST_P(ProtoStructValueTest, DoubleListHasField) { - TestHasField(memory_manager(), ProtoStructType::FieldId("repeated_double"), - [](TestAllTypes& message) { message.add_repeated_double(1.0); }); + TEST_HAS_FIELD( + memory_manager(), FieldId("repeated_double"), + [](TestAllTypes& message) { message.add_repeated_double(1.0); }); } TEST_P(ProtoStructValueTest, BytesListHasField) { - TestHasField( - memory_manager(), ProtoStructType::FieldId("repeated_bytes"), + TEST_HAS_FIELD( + memory_manager(), FieldId("repeated_bytes"), [](TestAllTypes& message) { message.add_repeated_bytes("foo"); }); } TEST_P(ProtoStructValueTest, StringListHasField) { - TestHasField( - memory_manager(), ProtoStructType::FieldId("repeated_string"), + TEST_HAS_FIELD( + memory_manager(), FieldId("repeated_string"), [](TestAllTypes& message) { message.add_repeated_string("foo"); }); } TEST_P(ProtoStructValueTest, DurationListHasField) { - TestHasField(memory_manager(), ProtoStructType::FieldId("repeated_duration"), - [](TestAllTypes& message) { - message.add_repeated_duration()->set_seconds(1); - }); + TEST_HAS_FIELD(memory_manager(), FieldId("repeated_duration"), + [](TestAllTypes& message) { + message.add_repeated_duration()->set_seconds(1); + }); } TEST_P(ProtoStructValueTest, TimestampListHasField) { - TestHasField(memory_manager(), ProtoStructType::FieldId("repeated_timestamp"), - [](TestAllTypes& message) { - message.add_repeated_timestamp()->set_seconds(1); - }); + TEST_HAS_FIELD(memory_manager(), FieldId("repeated_timestamp"), + [](TestAllTypes& message) { + message.add_repeated_timestamp()->set_seconds(1); + }); } TEST_P(ProtoStructValueTest, EnumListHasField) { - TestHasField(memory_manager(), - ProtoStructType::FieldId("repeated_nested_enum"), - [](TestAllTypes& message) { - message.add_repeated_nested_enum(TestAllTypes::BAR); - }); + TEST_HAS_FIELD(memory_manager(), FieldId("repeated_nested_enum"), + [](TestAllTypes& message) { + message.add_repeated_nested_enum(TestAllTypes::BAR); + }); } TEST_P(ProtoStructValueTest, MessageListHasField) { - TestHasField( - memory_manager(), ProtoStructType::FieldId("repeated_nested_message"), + TEST_HAS_FIELD( + memory_manager(), FieldId("repeated_nested_message"), [](TestAllTypes& message) { message.add_repeated_nested_message(); }); } TEST_P(ProtoStructValueTest, BoolWrapperListHasField) { - TestHasField( - memory_manager(), ProtoStructType::FieldId("repeated_bool_wrapper"), + TEST_HAS_FIELD( + memory_manager(), FieldId("repeated_bool_wrapper"), [](TestAllTypes& message) { message.add_repeated_bool_wrapper(); }); } TEST_P(ProtoStructValueTest, Int32WrapperListHasField) { - TestHasField( - memory_manager(), ProtoStructType::FieldId("repeated_int32_wrapper"), + TEST_HAS_FIELD( + memory_manager(), FieldId("repeated_int32_wrapper"), [](TestAllTypes& message) { message.add_repeated_int32_wrapper(); }); } TEST_P(ProtoStructValueTest, Int64WrapperListHasField) { - TestHasField( - memory_manager(), ProtoStructType::FieldId("repeated_int64_wrapper"), + TEST_HAS_FIELD( + memory_manager(), FieldId("repeated_int64_wrapper"), [](TestAllTypes& message) { message.add_repeated_int64_wrapper(); }); } TEST_P(ProtoStructValueTest, Uint32WrapperListHasField) { - TestHasField( - memory_manager(), ProtoStructType::FieldId("repeated_uint32_wrapper"), + TEST_HAS_FIELD( + memory_manager(), FieldId("repeated_uint32_wrapper"), [](TestAllTypes& message) { message.add_repeated_uint32_wrapper(); }); } TEST_P(ProtoStructValueTest, Uint64WrapperListHasField) { - TestHasField( - memory_manager(), ProtoStructType::FieldId("repeated_uint64_wrapper"), + TEST_HAS_FIELD( + memory_manager(), FieldId("repeated_uint64_wrapper"), [](TestAllTypes& message) { message.add_repeated_uint64_wrapper(); }); } TEST_P(ProtoStructValueTest, FloatWrapperListHasField) { - TestHasField( - memory_manager(), ProtoStructType::FieldId("repeated_float_wrapper"), + TEST_HAS_FIELD( + memory_manager(), FieldId("repeated_float_wrapper"), [](TestAllTypes& message) { message.add_repeated_float_wrapper(); }); } TEST_P(ProtoStructValueTest, DoubleWrapperListHasField) { - TestHasField( - memory_manager(), ProtoStructType::FieldId("repeated_double_wrapper"), + TEST_HAS_FIELD( + memory_manager(), FieldId("repeated_double_wrapper"), [](TestAllTypes& message) { message.add_repeated_double_wrapper(); }); } TEST_P(ProtoStructValueTest, BytesWrapperListHasField) { - TestHasField( - memory_manager(), ProtoStructType::FieldId("repeated_bytes_wrapper"), + TEST_HAS_FIELD( + memory_manager(), FieldId("repeated_bytes_wrapper"), [](TestAllTypes& message) { message.add_repeated_bytes_wrapper(); }); } TEST_P(ProtoStructValueTest, StringWrapperListHasField) { - TestHasField( - memory_manager(), ProtoStructType::FieldId("repeated_string_wrapper"), + TEST_HAS_FIELD( + memory_manager(), FieldId("repeated_string_wrapper"), [](TestAllTypes& message) { message.add_repeated_string_wrapper(); }); } TEST_P(ProtoStructValueTest, ListValueListHasField) { - TestHasField( - memory_manager(), ProtoStructType::FieldId("repeated_list_value"), + TEST_HAS_FIELD( + memory_manager(), FieldId("repeated_list_value"), [](TestAllTypes& message) { message.add_repeated_list_value(); }); } TEST_P(ProtoStructValueTest, StructListHasField) { - TestHasField(memory_manager(), ProtoStructType::FieldId("repeated_struct"), - [](TestAllTypes& message) { message.add_repeated_struct(); }); + TEST_HAS_FIELD(memory_manager(), FieldId("repeated_struct"), + [](TestAllTypes& message) { message.add_repeated_struct(); }); } TEST_P(ProtoStructValueTest, ValueListHasField) { - TestHasField(memory_manager(), ProtoStructType::FieldId("repeated_value"), - [](TestAllTypes& message) { message.add_repeated_value(); }); + TEST_HAS_FIELD(memory_manager(), FieldId("repeated_value"), + [](TestAllTypes& message) { message.add_repeated_value(); }); } -TEST_P(ProtoStructValueTest, NullValueGetField) { - TypeFactory type_factory(memory_manager()); +void TestGetField( + MemoryManager& memory_manager, FieldId id, + absl::FunctionRef&)> unset_field_tester, + absl::FunctionRef test_message_maker, + absl::FunctionRef&)> + set_field_tester) { + TypeFactory type_factory(memory_manager); ProtoTypeProvider type_provider; TypeManager type_manager(type_factory, type_provider); ValueFactory value_factory(type_manager); @@ -449,374 +442,232 @@ TEST_P(ProtoStructValueTest, NullValueGetField) { ProtoValue::Create(value_factory, CreateTestMessage())); ASSERT_OK_AND_ASSIGN( auto field, - value_without->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("null_value"))); - EXPECT_TRUE(field->Is()); + value_without->GetField(StructValue::GetFieldContext(value_factory), id)); + ASSERT_NO_FATAL_FAILURE(unset_field_tester(field)); ASSERT_OK_AND_ASSIGN( auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.set_null_value(NULL_VALUE); - }))); + ProtoValue::Create(value_factory, CreateTestMessage(test_message_maker))); ASSERT_OK_AND_ASSIGN( - field, value_with->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("null_value"))); - EXPECT_TRUE(field->Is()); + field, + value_with->GetField(StructValue::GetFieldContext(value_factory), id)); + ASSERT_NO_FATAL_FAILURE(set_field_tester(value_factory, field)); +} + +void TestGetField( + MemoryManager& memory_manager, FieldId id, + absl::FunctionRef&)> unset_field_tester, + absl::FunctionRef test_message_maker, + absl::FunctionRef&)> set_field_tester) { + TestGetField(memory_manager, id, unset_field_tester, test_message_maker, + [&](ValueFactory& value_factory, const Handle& field) { + set_field_tester(field); + }); +} + +#define TEST_GET_FIELD(...) ASSERT_NO_FATAL_FAILURE(TestGetField(__VA_ARGS__)) + +TEST_P(ProtoStructValueTest, NullValueGetField) { + TEST_GET_FIELD( + memory_manager(), FieldId("null_value"), + [](const Handle& field) { EXPECT_TRUE(field->Is()); }, + [](TestAllTypes& message) { message.set_null_value(NULL_VALUE); }, + [](const Handle& field) { EXPECT_TRUE(field->Is()); }); } TEST_P(ProtoStructValueTest, OptionalNullValueGetField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN( - auto field, - value_without->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("optional_null_value"))); - EXPECT_TRUE(field->Is()); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.set_optional_null_value(NULL_VALUE); - }))); - ASSERT_OK_AND_ASSIGN( - field, - value_with->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("optional_null_value"))); - EXPECT_TRUE(field->Is()); + TEST_GET_FIELD( + memory_manager(), FieldId("optional_null_value"), + [](const Handle& field) { EXPECT_TRUE(field->Is()); }, + [](TestAllTypes& message) { + message.set_optional_null_value(NULL_VALUE); + }, + [](const Handle& field) { EXPECT_TRUE(field->Is()); }); } TEST_P(ProtoStructValueTest, BoolGetField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN( - auto field, - value_without->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("single_bool"))); - EXPECT_FALSE(field.As()->value()); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.set_single_bool(true); - }))); - ASSERT_OK_AND_ASSIGN( - field, value_with->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("single_bool"))); - EXPECT_TRUE(field.As()->value()); + TEST_GET_FIELD( + memory_manager(), FieldId("single_bool"), + [](const Handle& field) { + EXPECT_FALSE(field.As()->value()); + }, + [](TestAllTypes& message) { message.set_single_bool(true); }, + [](const Handle& field) { + EXPECT_TRUE(field.As()->value()); + }); } TEST_P(ProtoStructValueTest, Int32GetField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN( - auto field, - value_without->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("single_int32"))); - EXPECT_EQ(field.As()->value(), 0); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.set_single_int32(1); - }))); - ASSERT_OK_AND_ASSIGN( - field, value_with->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("single_int32"))); - EXPECT_EQ(field.As()->value(), 1); + TEST_GET_FIELD( + memory_manager(), FieldId("single_int32"), + [](const Handle& field) { + EXPECT_EQ(field.As()->value(), 0); + }, + [](TestAllTypes& message) { message.set_single_int32(1); }, + [](const Handle& field) { + EXPECT_EQ(field.As()->value(), 1); + }); } TEST_P(ProtoStructValueTest, Int64GetField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN( - auto field, - value_without->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("single_int64"))); - EXPECT_EQ(field.As()->value(), 0); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.set_single_int64(1); - }))); - ASSERT_OK_AND_ASSIGN( - field, value_with->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("single_int64"))); - EXPECT_EQ(field.As()->value(), 1); + TEST_GET_FIELD( + memory_manager(), FieldId("single_int64"), + [](const Handle& field) { + EXPECT_EQ(field.As()->value(), 0); + }, + [](TestAllTypes& message) { message.set_single_int64(1); }, + [](const Handle& field) { + EXPECT_EQ(field.As()->value(), 1); + }); } TEST_P(ProtoStructValueTest, Uint32GetField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN( - auto field, - value_without->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("single_uint32"))); - EXPECT_EQ(field.As()->value(), 0); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.set_single_uint32(1); - }))); - ASSERT_OK_AND_ASSIGN( - field, value_with->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("single_uint32"))); - EXPECT_EQ(field.As()->value(), 1); + TEST_GET_FIELD( + memory_manager(), FieldId("single_uint32"), + [](const Handle& field) { + EXPECT_EQ(field.As()->value(), 0); + }, + [](TestAllTypes& message) { message.set_single_uint32(1); }, + [](const Handle& field) { + EXPECT_EQ(field.As()->value(), 1); + }); } TEST_P(ProtoStructValueTest, Uint64GetField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN( - auto field, - value_without->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("single_uint64"))); - EXPECT_EQ(field.As()->value(), 0); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.set_single_uint64(1); - }))); - ASSERT_OK_AND_ASSIGN( - field, value_with->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("single_uint64"))); - EXPECT_EQ(field.As()->value(), 1); + TEST_GET_FIELD( + memory_manager(), FieldId("single_uint64"), + [](const Handle& field) { + EXPECT_EQ(field.As()->value(), 0); + }, + [](TestAllTypes& message) { message.set_single_uint64(1); }, + [](const Handle& field) { + EXPECT_EQ(field.As()->value(), 1); + }); } TEST_P(ProtoStructValueTest, FloatGetField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN( - auto field, - value_without->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("single_float"))); - EXPECT_EQ(field.As()->value(), 0); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.set_single_float(1.0); - }))); - ASSERT_OK_AND_ASSIGN( - field, value_with->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("single_float"))); - EXPECT_EQ(field.As()->value(), 1); + TEST_GET_FIELD( + memory_manager(), FieldId("single_float"), + [](const Handle& field) { + EXPECT_EQ(field.As()->value(), 0); + }, + [](TestAllTypes& message) { message.set_single_float(1.0); }, + [](const Handle& field) { + EXPECT_EQ(field.As()->value(), 1); + }); } TEST_P(ProtoStructValueTest, DoubleGetField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN( - auto field, - value_without->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("single_double"))); - EXPECT_EQ(field.As()->value(), 0); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.set_single_double(1.0); - }))); - ASSERT_OK_AND_ASSIGN( - field, value_with->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("single_double"))); - EXPECT_EQ(field.As()->value(), 1); + TEST_GET_FIELD( + memory_manager(), FieldId("single_double"), + [](const Handle& field) { + EXPECT_EQ(field.As()->value(), 0); + }, + [](TestAllTypes& message) { message.set_single_double(1.0); }, + [](const Handle& field) { + EXPECT_EQ(field.As()->value(), 1); + }); } TEST_P(ProtoStructValueTest, BytesGetField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN( - auto field, - value_without->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("single_bytes"))); - EXPECT_EQ(field.As()->ToString(), ""); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.set_single_bytes("foo"); - }))); - ASSERT_OK_AND_ASSIGN( - field, value_with->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("single_bytes"))); - EXPECT_EQ(field.As()->ToString(), "foo"); + TEST_GET_FIELD( + memory_manager(), FieldId("single_bytes"), + [](const Handle& field) { + EXPECT_EQ(field.As()->ToString(), ""); + }, + [](TestAllTypes& message) { message.set_single_bytes("foo"); }, + [](const Handle& field) { + EXPECT_EQ(field.As()->ToString(), "foo"); + }); } TEST_P(ProtoStructValueTest, StringGetField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN( - auto field, - value_without->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("single_string"))); - EXPECT_EQ(field.As()->ToString(), ""); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.set_single_string("foo"); - }))); - ASSERT_OK_AND_ASSIGN( - field, value_with->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("single_string"))); - EXPECT_EQ(field.As()->ToString(), "foo"); + TEST_GET_FIELD( + memory_manager(), FieldId("single_string"), + [](const Handle& field) { + EXPECT_EQ(field.As()->ToString(), ""); + }, + [](TestAllTypes& message) { message.set_single_string("foo"); }, + [](const Handle& field) { + EXPECT_EQ(field.As()->ToString(), "foo"); + }); } TEST_P(ProtoStructValueTest, DurationGetField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN( - auto field, - value_without->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("single_duration"))); - EXPECT_EQ(field.As()->value(), absl::ZeroDuration()); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.mutable_single_duration()->set_seconds(1); - }))); - ASSERT_OK_AND_ASSIGN( - field, value_with->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("single_duration"))); - EXPECT_EQ(field.As()->value(), absl::Seconds(1)); + TEST_GET_FIELD( + memory_manager(), FieldId("single_duration"), + [](const Handle& field) { + EXPECT_EQ(field.As()->value(), absl::ZeroDuration()); + }, + [](TestAllTypes& message) { + message.mutable_single_duration()->set_seconds(1); + }, + [](const Handle& field) { + EXPECT_EQ(field.As()->value(), absl::Seconds(1)); + }); } TEST_P(ProtoStructValueTest, TimestampGetField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN( - auto field, - value_without->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("single_timestamp"))); - EXPECT_EQ(field.As()->value(), absl::UnixEpoch()); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.mutable_single_timestamp()->set_seconds(1); - }))); - ASSERT_OK_AND_ASSIGN( - field, - value_with->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("single_timestamp"))); - EXPECT_EQ(field.As()->value(), - absl::UnixEpoch() + absl::Seconds(1)); + TEST_GET_FIELD( + memory_manager(), FieldId("single_timestamp"), + [](const Handle& field) { + EXPECT_EQ(field.As()->value(), absl::UnixEpoch()); + }, + [](TestAllTypes& message) { + message.mutable_single_timestamp()->set_seconds(1); + }, + [](const Handle& field) { + EXPECT_EQ(field.As()->value(), + absl::UnixEpoch() + absl::Seconds(1)); + }); } TEST_P(ProtoStructValueTest, EnumGetField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN( - auto field, - value_without->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("standalone_enum"))); - EXPECT_EQ(field.As()->number(), 0); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.set_standalone_enum(TestAllTypes::BAR); - }))); - ASSERT_OK_AND_ASSIGN( - field, value_with->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("standalone_enum"))); - EXPECT_EQ(field.As()->number(), 1); + TEST_GET_FIELD( + memory_manager(), FieldId("standalone_enum"), + [](const Handle& field) { + EXPECT_EQ(field.As()->number(), 0); + }, + [](TestAllTypes& message) { + message.set_standalone_enum(TestAllTypes::BAR); + }, + [](const Handle& field) { + EXPECT_EQ(field.As()->number(), 1); + }); } TEST_P(ProtoStructValueTest, MessageGetField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN( - auto field, - value_without->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("standalone_message"))); - EXPECT_THAT(*field.As()->value(), - EqualsProto(CreateTestMessage().standalone_message())); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.mutable_standalone_message()->set_bb(1); - }))); - ASSERT_OK_AND_ASSIGN( - field, - value_with->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("standalone_message"))); - TestAllTypes::NestedMessage expected = - CreateTestMessage([](TestAllTypes& message) { + TEST_GET_FIELD( + memory_manager(), FieldId("standalone_message"), + [](const Handle& field) { + EXPECT_THAT(*field.As()->value(), + EqualsProto(CreateTestMessage().standalone_message())); + }, + [](TestAllTypes& message) { message.mutable_standalone_message()->set_bb(1); - }).standalone_message(); - TestAllTypes::NestedMessage scratch; - EXPECT_THAT(*field.As()->value(), EqualsProto(expected)); - EXPECT_THAT(*field.As()->value(scratch), - EqualsProto(expected)); - google::protobuf::Arena arena; - EXPECT_THAT(*field.As()->value(arena), - EqualsProto(expected)); -} - -TEST_P(ProtoStructValueTest, BoolWrapperGetField) { - TypeFactory type_factory(memory_manager()); + }, + [](const Handle& field) { + TestAllTypes::NestedMessage expected = + CreateTestMessage([](TestAllTypes& message) { + message.mutable_standalone_message()->set_bb(1); + }).standalone_message(); + TestAllTypes::NestedMessage scratch; + EXPECT_THAT(*field.As()->value(), + EqualsProto(expected)); + EXPECT_THAT(*field.As()->value(scratch), + EqualsProto(expected)); + google::protobuf::Arena arena; + EXPECT_THAT(*field.As()->value(arena), + EqualsProto(expected)); + }); +} + +template +void TestGetWrapperField(MemoryManager& memory_manager, FieldId id, + UnsetFieldTester&& unset_field_tester, + TestMessageMaker&& test_message_maker, + SetFieldTester&& set_field_tester) { + TypeFactory type_factory(memory_manager); ProtoTypeProvider type_provider; TypeManager type_manager(type_factory, type_provider); ValueFactory value_factory(type_manager); @@ -826,366 +677,207 @@ TEST_P(ProtoStructValueTest, BoolWrapperGetField) { auto field, value_without->GetField(StructValue::GetFieldContext(value_factory) .set_unbox_null_wrapper_types(true), - ProtoStructType::FieldId("single_bool_wrapper"))); + id)); EXPECT_TRUE(field->Is()); ASSERT_OK_AND_ASSIGN( - field, - value_without->GetField(StructValue::GetFieldContext(value_factory) - .set_unbox_null_wrapper_types(false), - ProtoStructType::FieldId("single_bool_wrapper"))); - EXPECT_TRUE(field->Is()); + field, value_without->GetField(StructValue::GetFieldContext(value_factory) + .set_unbox_null_wrapper_types(false), + id)); + ASSERT_NO_FATAL_FAILURE( + (std::forward(unset_field_tester)(field))); ASSERT_OK_AND_ASSIGN( auto value_with, - ProtoValue::Create( - value_factory, CreateTestMessage([](TestAllTypes& message) { - message.mutable_single_bool_wrapper()->set_value(true); - }))); + ProtoValue::Create(value_factory, + CreateTestMessage(std::forward( + test_message_maker)))); ASSERT_OK_AND_ASSIGN( field, - value_with->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("single_bool_wrapper"))); - EXPECT_TRUE(field.As()->value()); + value_with->GetField(StructValue::GetFieldContext(value_factory), id)); + ASSERT_NO_FATAL_FAILURE( + (std::forward(set_field_tester)(field))); +} + +#define TEST_GET_WRAPPER_FIELD(...) \ + ASSERT_NO_FATAL_FAILURE(TestGetWrapperField(__VA_ARGS__)) + +TEST_P(ProtoStructValueTest, BoolWrapperGetField) { + TEST_GET_WRAPPER_FIELD( + memory_manager(), FieldId("single_bool_wrapper"), + [](const Handle& field) { + EXPECT_FALSE(field.As()->value()); + }, + [](TestAllTypes& message) { + message.mutable_single_bool_wrapper()->set_value(true); + }, + [](const Handle& field) { + EXPECT_TRUE(field.As()->value()); + }); } TEST_P(ProtoStructValueTest, Int32WrapperGetField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN(auto field, - value_without->GetField( - StructValue::GetFieldContext(value_factory) - .set_unbox_null_wrapper_types(true), - ProtoStructType::FieldId("single_int32_wrapper"))); - EXPECT_TRUE(field->Is()); - ASSERT_OK_AND_ASSIGN(field, - value_without->GetField( - StructValue::GetFieldContext(value_factory) - .set_unbox_null_wrapper_types(false), - ProtoStructType::FieldId("single_int32_wrapper"))); - EXPECT_TRUE(field->Is()); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.mutable_single_int32_wrapper()->set_value(1); - }))); - ASSERT_OK_AND_ASSIGN( - field, - value_with->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("single_int32_wrapper"))); - EXPECT_EQ(field.As()->value(), 1); + TEST_GET_WRAPPER_FIELD( + memory_manager(), FieldId("single_int32_wrapper"), + [](const Handle& field) { + EXPECT_EQ(field.As()->value(), 0); + }, + [](TestAllTypes& message) { + message.mutable_single_int32_wrapper()->set_value(1); + }, + [](const Handle& field) { + EXPECT_EQ(field.As()->value(), 1); + }); } TEST_P(ProtoStructValueTest, Int64WrapperGetField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN(auto field, - value_without->GetField( - StructValue::GetFieldContext(value_factory) - .set_unbox_null_wrapper_types(true), - ProtoStructType::FieldId("single_int64_wrapper"))); - EXPECT_TRUE(field->Is()); - ASSERT_OK_AND_ASSIGN(field, - value_without->GetField( - StructValue::GetFieldContext(value_factory) - .set_unbox_null_wrapper_types(false), - ProtoStructType::FieldId("single_int64_wrapper"))); - EXPECT_TRUE(field->Is()); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.mutable_single_int64_wrapper()->set_value(1); - }))); - ASSERT_OK_AND_ASSIGN( - field, - value_with->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("single_int64_wrapper"))); - EXPECT_EQ(field.As()->value(), 1); + TEST_GET_WRAPPER_FIELD( + memory_manager(), FieldId("single_int64_wrapper"), + [](const Handle& field) { + EXPECT_EQ(field.As()->value(), 0); + }, + [](TestAllTypes& message) { + message.mutable_single_int64_wrapper()->set_value(1); + }, + [](const Handle& field) { + EXPECT_EQ(field.As()->value(), 1); + }); } TEST_P(ProtoStructValueTest, Uint32WrapperGetField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN(auto field, - value_without->GetField( - StructValue::GetFieldContext(value_factory) - .set_unbox_null_wrapper_types(true), - ProtoStructType::FieldId("single_uint32_wrapper"))); - EXPECT_TRUE(field->Is()); - ASSERT_OK_AND_ASSIGN(field, - value_without->GetField( - StructValue::GetFieldContext(value_factory) - .set_unbox_null_wrapper_types(false), - ProtoStructType::FieldId("single_uint32_wrapper"))); - EXPECT_TRUE(field->Is()); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create( - value_factory, CreateTestMessage([](TestAllTypes& message) { - message.mutable_single_uint32_wrapper()->set_value(1); - }))); - ASSERT_OK_AND_ASSIGN( - field, - value_with->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("single_uint32_wrapper"))); - EXPECT_EQ(field.As()->value(), 1); + TEST_GET_WRAPPER_FIELD( + memory_manager(), FieldId("single_uint32_wrapper"), + [](const Handle& field) { + EXPECT_EQ(field.As()->value(), 0); + }, + [](TestAllTypes& message) { + message.mutable_single_uint32_wrapper()->set_value(1); + }, + [](const Handle& field) { + EXPECT_EQ(field.As()->value(), 1); + }); } TEST_P(ProtoStructValueTest, Uint64WrapperGetField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN(auto field, - value_without->GetField( - StructValue::GetFieldContext(value_factory) - .set_unbox_null_wrapper_types(true), - ProtoStructType::FieldId("single_uint64_wrapper"))); - EXPECT_TRUE(field->Is()); - ASSERT_OK_AND_ASSIGN(field, - value_without->GetField( - StructValue::GetFieldContext(value_factory) - .set_unbox_null_wrapper_types(false), - ProtoStructType::FieldId("single_uint64_wrapper"))); - EXPECT_TRUE(field->Is()); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create( - value_factory, CreateTestMessage([](TestAllTypes& message) { - message.mutable_single_uint64_wrapper()->set_value(1); - }))); - ASSERT_OK_AND_ASSIGN( - field, - value_with->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("single_uint64_wrapper"))); - EXPECT_EQ(field.As()->value(), 1); + TEST_GET_WRAPPER_FIELD( + memory_manager(), FieldId("single_uint64_wrapper"), + [](const Handle& field) { + EXPECT_EQ(field.As()->value(), 0); + }, + [](TestAllTypes& message) { + message.mutable_single_uint64_wrapper()->set_value(1); + }, + [](const Handle& field) { + EXPECT_EQ(field.As()->value(), 1); + }); } TEST_P(ProtoStructValueTest, FloatWrapperGetField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN(auto field, - value_without->GetField( - StructValue::GetFieldContext(value_factory) - .set_unbox_null_wrapper_types(true), - ProtoStructType::FieldId("single_float_wrapper"))); - EXPECT_TRUE(field->Is()); - ASSERT_OK_AND_ASSIGN(field, - value_without->GetField( - StructValue::GetFieldContext(value_factory) - .set_unbox_null_wrapper_types(false), - ProtoStructType::FieldId("single_float_wrapper"))); - EXPECT_TRUE(field->Is()); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create( - value_factory, CreateTestMessage([](TestAllTypes& message) { - message.mutable_single_float_wrapper()->set_value(1.0); - }))); - ASSERT_OK_AND_ASSIGN( - field, - value_with->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("single_float_wrapper"))); - EXPECT_EQ(field.As()->value(), 1); + TEST_GET_WRAPPER_FIELD( + memory_manager(), FieldId("single_float_wrapper"), + [](const Handle& field) { + EXPECT_EQ(field.As()->value(), 0); + }, + [](TestAllTypes& message) { + message.mutable_single_float_wrapper()->set_value(1.0); + }, + [](const Handle& field) { + EXPECT_EQ(field.As()->value(), 1); + }); } TEST_P(ProtoStructValueTest, DoubleWrapperGetField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN(auto field, - value_without->GetField( - StructValue::GetFieldContext(value_factory) - .set_unbox_null_wrapper_types(true), - ProtoStructType::FieldId("single_double_wrapper"))); - EXPECT_TRUE(field->Is()); - ASSERT_OK_AND_ASSIGN(field, - value_without->GetField( - StructValue::GetFieldContext(value_factory) - .set_unbox_null_wrapper_types(false), - ProtoStructType::FieldId("single_double_wrapper"))); - EXPECT_TRUE(field->Is()); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create( - value_factory, CreateTestMessage([](TestAllTypes& message) { - message.mutable_single_double_wrapper()->set_value(1.0); - }))); - ASSERT_OK_AND_ASSIGN( - field, - value_with->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("single_double_wrapper"))); - EXPECT_EQ(field.As()->value(), 1); + TEST_GET_WRAPPER_FIELD( + memory_manager(), FieldId("single_double_wrapper"), + [](const Handle& field) { + EXPECT_EQ(field.As()->value(), 0); + }, + [](TestAllTypes& message) { + message.mutable_single_double_wrapper()->set_value(1.0); + }, + [](const Handle& field) { + EXPECT_EQ(field.As()->value(), 1); + }); } TEST_P(ProtoStructValueTest, BytesWrapperGetField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN(auto field, - value_without->GetField( - StructValue::GetFieldContext(value_factory) - .set_unbox_null_wrapper_types(true), - ProtoStructType::FieldId("single_bytes_wrapper"))); - EXPECT_TRUE(field->Is()); - ASSERT_OK_AND_ASSIGN(field, - value_without->GetField( - StructValue::GetFieldContext(value_factory) - .set_unbox_null_wrapper_types(false), - ProtoStructType::FieldId("single_bytes_wrapper"))); - EXPECT_TRUE(field->Is()); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create( - value_factory, CreateTestMessage([](TestAllTypes& message) { - message.mutable_single_bytes_wrapper()->set_value("foo"); - }))); - ASSERT_OK_AND_ASSIGN( - field, - value_with->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("single_bytes_wrapper"))); - EXPECT_EQ(field.As()->ToString(), "foo"); + TEST_GET_WRAPPER_FIELD( + memory_manager(), FieldId("single_bytes_wrapper"), + [](const Handle& field) { + EXPECT_EQ(field.As()->ToString(), ""); + }, + [](TestAllTypes& message) { + message.mutable_single_bytes_wrapper()->set_value("foo"); + }, + [](const Handle& field) { + EXPECT_EQ(field.As()->ToString(), "foo"); + }); } TEST_P(ProtoStructValueTest, StringWrapperGetField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN(auto field, - value_without->GetField( - StructValue::GetFieldContext(value_factory) - .set_unbox_null_wrapper_types(true), - ProtoStructType::FieldId("single_string_wrapper"))); - EXPECT_TRUE(field->Is()); - ASSERT_OK_AND_ASSIGN(field, - value_without->GetField( - StructValue::GetFieldContext(value_factory) - .set_unbox_null_wrapper_types(false), - ProtoStructType::FieldId("single_string_wrapper"))); - EXPECT_TRUE(field->Is()); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create( - value_factory, CreateTestMessage([](TestAllTypes& message) { - message.mutable_single_string_wrapper()->set_value("foo"); - }))); - ASSERT_OK_AND_ASSIGN( - field, - value_with->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("single_string_wrapper"))); - EXPECT_EQ(field.As()->ToString(), "foo"); + TEST_GET_WRAPPER_FIELD( + memory_manager(), FieldId("single_string_wrapper"), + [](const Handle& field) { + EXPECT_EQ(field.As()->ToString(), ""); + }, + [](TestAllTypes& message) { + message.mutable_single_string_wrapper()->set_value("foo"); + }, + [](const Handle& field) { + EXPECT_EQ(field.As()->ToString(), "foo"); + }); } TEST_P(ProtoStructValueTest, StructGetField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN( - auto field, - value_without->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("single_struct"))); - ASSERT_TRUE(field->Is()); - EXPECT_TRUE(field->As().empty()); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create( - value_factory, CreateTestMessage([](TestAllTypes& message) { - google::protobuf::Value value_proto; - value_proto.set_bool_value(true); - message.mutable_single_struct()->mutable_fields()->insert( - {"foo", std::move(value_proto)}); - }))); - ASSERT_OK_AND_ASSIGN( - field, value_with->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("single_struct"))); - ASSERT_TRUE(field->Is()); - EXPECT_EQ(field->As().size(), 1); - ASSERT_OK_AND_ASSIGN(auto key, value_factory.CreateStringValue("foo")); - EXPECT_THAT( - field->As().Get(MapValue::GetContext(value_factory), key), - IsOkAndHolds(Optional(ValueOf(value_factory, true)))); + TEST_GET_FIELD( + memory_manager(), FieldId("single_struct"), + [](const Handle& field) { + ASSERT_TRUE(field->Is()); + EXPECT_TRUE(field->As().empty()); + }, + [](TestAllTypes& message) { + google::protobuf::Value value_proto; + value_proto.set_bool_value(true); + message.mutable_single_struct()->mutable_fields()->insert( + {"foo", std::move(value_proto)}); + }, + [](ValueFactory& value_factory, const Handle& field) { + ASSERT_TRUE(field->Is()); + EXPECT_EQ(field->As().size(), 1); + ASSERT_OK_AND_ASSIGN(auto key, value_factory.CreateStringValue("foo")); + EXPECT_THAT( + field->As().Get(MapValue::GetContext(value_factory), key), + IsOkAndHolds(Optional(ValueOf(value_factory, true)))); + }); } TEST_P(ProtoStructValueTest, ListValueGetField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN( - auto field, - value_without->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("list_value"))); - ASSERT_TRUE(field->Is()); - EXPECT_TRUE(field->As().empty()); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create( - value_factory, CreateTestMessage([](TestAllTypes& message) { - message.mutable_list_value()->add_values()->set_bool_value(true); - }))); - ASSERT_OK_AND_ASSIGN( - field, value_with->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("list_value"))); - ASSERT_TRUE(field->Is()); - EXPECT_EQ(field->As().size(), 1); - EXPECT_THAT( - field->As().Get(ListValue::GetContext(value_factory), 0), - IsOkAndHolds(ValueOf(value_factory, true))); + TEST_GET_FIELD( + memory_manager(), FieldId("list_value"), + [](const Handle& field) { + ASSERT_TRUE(field->Is()); + EXPECT_TRUE(field->As().empty()); + }, + [](TestAllTypes& message) { + message.mutable_list_value()->add_values()->set_bool_value(true); + }, + [](ValueFactory& value_factory, const Handle& field) { + ASSERT_TRUE(field->Is()); + EXPECT_EQ(field->As().size(), 1); + EXPECT_THAT( + field->As().Get(ListValue::GetContext(value_factory), 0), + IsOkAndHolds(ValueOf(value_factory, true))); + }); } TEST_P(ProtoStructValueTest, ValueGetField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN( - auto field, - value_without->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("single_value"))); - EXPECT_TRUE(field->Is()); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.mutable_single_value()->set_bool_value(true); - }))); - EXPECT_THAT(value_with->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("single_value")), - IsOkAndHolds(ValueOf(value_factory, true))); + TEST_GET_FIELD( + memory_manager(), FieldId("single_value"), + [](const Handle& field) { EXPECT_TRUE(field->Is()); }, + [](TestAllTypes& message) { + message.mutable_single_value()->set_bool_value(true); + }, + [](const Handle& field) { + EXPECT_TRUE(field->As().value()); + }); } TEST_P(ProtoStructValueTest, NullValueListGetField) { From 0b15d05aec961f8b9cd4d1c7bbe11b44beef7b19 Mon Sep 17 00:00:00 2001 From: jcking Date: Thu, 11 May 2023 18:49:54 +0000 Subject: [PATCH 260/303] Forbid direct creation of `ConstantId` and `FieldId` PiperOrigin-RevId: 531265509 --- base/internal/type.h | 1 + base/type_test.cc | 81 ++-- base/types/enum_type.cc | 41 ++ base/types/enum_type.h | 60 ++- base/types/struct_type.cc | 4 +- base/types/struct_type.h | 76 ++- base/value_factory.h | 10 +- base/value_test.cc | 101 ++-- base/values/struct_value.cc | 8 +- base/values/struct_value.h | 19 + eval/eval/select_step.cc | 7 +- eval/internal/interop_test.cc | 34 +- extensions/protobuf/enum_type.cc | 13 +- extensions/protobuf/enum_type.h | 3 + extensions/protobuf/enum_type_test.cc | 6 +- extensions/protobuf/struct_type.cc | 24 +- extensions/protobuf/struct_type.h | 3 + extensions/protobuf/struct_type_test.cc | 39 +- extensions/protobuf/struct_value.cc | 17 +- extensions/protobuf/struct_value_test.cc | 567 +++++++++++------------ 20 files changed, 608 insertions(+), 506 deletions(-) diff --git a/base/internal/type.h b/base/internal/type.h index fb73d2125..c9ddbe09c 100644 --- a/base/internal/type.h +++ b/base/internal/type.h @@ -44,6 +44,7 @@ class LegacyListType; class ModernListType; class LegacyMapType; class ModernMapType; +struct FieldIdFactory; template class SimpleType; diff --git a/base/type_test.cc b/base/type_test.cc index 56e244831..fdbdc2453 100644 --- a/base/type_test.cc +++ b/base/type_test.cc @@ -56,13 +56,14 @@ class TestEnumType final : public EnumType { "EnumType::NewConstantIterator is unimplemented"); } - protected: absl::StatusOr> FindConstantByName( absl::string_view name) const override { if (name == "VALUE1") { - return Constant("VALUE1", static_cast(TestEnum::kValue1)); + return Constant(MakeConstantId(TestEnum::kValue1), "VALUE1", + static_cast(TestEnum::kValue1)); } else if (name == "VALUE2") { - return Constant("VALUE2", static_cast(TestEnum::kValue2)); + return Constant(MakeConstantId(TestEnum::kValue2), "VALUE2", + static_cast(TestEnum::kValue2)); } return absl::nullopt; } @@ -71,9 +72,11 @@ class TestEnumType final : public EnumType { int64_t number) const override { switch (number) { case 1: - return Constant("VALUE1", static_cast(TestEnum::kValue1)); + return Constant(MakeConstantId(TestEnum::kValue1), "VALUE1", + static_cast(TestEnum::kValue1)); case 2: - return Constant("VALUE2", static_cast(TestEnum::kValue2)); + return Constant(MakeConstantId(TestEnum::kValue2), "VALUE2", + static_cast(TestEnum::kValue2)); default: return absl::nullopt; } @@ -104,17 +107,19 @@ class TestStructType final : public CEL_STRUCT_TYPE_CLASS { "StructType::NewFieldIterator() is unimplemented"); } - protected: absl::StatusOr> FindFieldByName( TypeManager& type_manager, absl::string_view name) const override { if (name == "bool_field") { - return Field("bool_field", 0, type_manager.type_factory().GetBoolType()); + return Field(MakeFieldId(0), "bool_field", 0, + type_manager.type_factory().GetBoolType()); } else if (name == "int_field") { - return Field("int_field", 1, type_manager.type_factory().GetIntType()); + return Field(MakeFieldId(1), "int_field", 1, + type_manager.type_factory().GetIntType()); } else if (name == "uint_field") { - return Field("uint_field", 2, type_manager.type_factory().GetUintType()); + return Field(MakeFieldId(2), "uint_field", 2, + type_manager.type_factory().GetUintType()); } else if (name == "double_field") { - return Field("double_field", 3, + return Field(MakeFieldId(3), "double_field", 3, type_manager.type_factory().GetDoubleType()); } return absl::nullopt; @@ -124,15 +129,16 @@ class TestStructType final : public CEL_STRUCT_TYPE_CLASS { TypeManager& type_manager, int64_t number) const override { switch (number) { case 0: - return Field("bool_field", 0, + return Field(MakeFieldId(0), "bool_field", 0, type_manager.type_factory().GetBoolType()); case 1: - return Field("int_field", 1, type_manager.type_factory().GetIntType()); + return Field(MakeFieldId(1), "int_field", 1, + type_manager.type_factory().GetIntType()); case 2: - return Field("uint_field", 2, + return Field(MakeFieldId(2), "uint_field", 2, type_manager.type_factory().GetUintType()); case 3: - return Field("double_field", 3, + return Field(MakeFieldId(3), "double_field", 3, type_manager.type_factory().GetDoubleType()); default: return absl::nullopt; @@ -502,29 +508,25 @@ TEST_P(EnumTypeTest, FindConstant) { ASSERT_OK_AND_ASSIGN(auto enum_type, type_factory.CreateEnumType()); - ASSERT_OK_AND_ASSIGN(auto value1, - enum_type->FindConstant(EnumType::ConstantId("VALUE1"))); + ASSERT_OK_AND_ASSIGN(auto value1, enum_type->FindConstantByName("VALUE1")); EXPECT_EQ(value1->name, "VALUE1"); EXPECT_EQ(value1->number, 1); - ASSERT_OK_AND_ASSIGN(value1, - enum_type->FindConstant(EnumType::ConstantId(1))); + ASSERT_OK_AND_ASSIGN(value1, enum_type->FindConstantByNumber(1)); EXPECT_EQ(value1->name, "VALUE1"); EXPECT_EQ(value1->number, 1); - ASSERT_OK_AND_ASSIGN(auto value2, - enum_type->FindConstant(EnumType::ConstantId("VALUE2"))); + ASSERT_OK_AND_ASSIGN(auto value2, enum_type->FindConstantByName("VALUE2")); EXPECT_EQ(value2->name, "VALUE2"); EXPECT_EQ(value2->number, 2); - ASSERT_OK_AND_ASSIGN(value2, - enum_type->FindConstant(EnumType::ConstantId(2))); + ASSERT_OK_AND_ASSIGN(value2, enum_type->FindConstantByNumber(2)); EXPECT_EQ(value2->name, "VALUE2"); EXPECT_EQ(value2->number, 2); - EXPECT_THAT(enum_type->FindConstant(EnumType::ConstantId("VALUE3")), + EXPECT_THAT(enum_type->FindConstantByName("VALUE3"), IsOkAndHolds(Eq(absl::nullopt))); - EXPECT_THAT(enum_type->FindConstant(EnumType::ConstantId(3)), + EXPECT_THAT(enum_type->FindConstantByNumber(3), IsOkAndHolds(Eq(absl::nullopt))); } @@ -542,61 +544,52 @@ TEST_P(StructTypeTest, FindField) { type_manager.type_factory().CreateStructType()); ASSERT_OK_AND_ASSIGN( - auto field1, - struct_type->FindField(type_manager, StructType::FieldId("bool_field"))); + auto field1, struct_type->FindFieldByName(type_manager, "bool_field")); EXPECT_EQ(field1->name, "bool_field"); EXPECT_EQ(field1->number, 0); EXPECT_EQ(field1->type, type_manager.type_factory().GetBoolType()); - ASSERT_OK_AND_ASSIGN( - field1, struct_type->FindField(type_manager, StructType::FieldId(0))); + ASSERT_OK_AND_ASSIGN(field1, struct_type->FindFieldByNumber(type_manager, 0)); EXPECT_EQ(field1->name, "bool_field"); EXPECT_EQ(field1->number, 0); EXPECT_EQ(field1->type, type_manager.type_factory().GetBoolType()); - ASSERT_OK_AND_ASSIGN( - auto field2, - struct_type->FindField(type_manager, StructType::FieldId("int_field"))); + ASSERT_OK_AND_ASSIGN(auto field2, + struct_type->FindFieldByName(type_manager, "int_field")); EXPECT_EQ(field2->name, "int_field"); EXPECT_EQ(field2->number, 1); EXPECT_EQ(field2->type, type_manager.type_factory().GetIntType()); - ASSERT_OK_AND_ASSIGN( - field2, struct_type->FindField(type_manager, StructType::FieldId(1))); + ASSERT_OK_AND_ASSIGN(field2, struct_type->FindFieldByNumber(type_manager, 1)); EXPECT_EQ(field2->name, "int_field"); EXPECT_EQ(field2->number, 1); EXPECT_EQ(field2->type, type_manager.type_factory().GetIntType()); ASSERT_OK_AND_ASSIGN( - auto field3, - struct_type->FindField(type_manager, StructType::FieldId("uint_field"))); + auto field3, struct_type->FindFieldByName(type_manager, "uint_field")); EXPECT_EQ(field3->name, "uint_field"); EXPECT_EQ(field3->number, 2); EXPECT_EQ(field3->type, type_manager.type_factory().GetUintType()); - ASSERT_OK_AND_ASSIGN( - field3, struct_type->FindField(type_manager, StructType::FieldId(2))); + ASSERT_OK_AND_ASSIGN(field3, struct_type->FindFieldByNumber(type_manager, 2)); EXPECT_EQ(field3->name, "uint_field"); EXPECT_EQ(field3->number, 2); EXPECT_EQ(field3->type, type_manager.type_factory().GetUintType()); ASSERT_OK_AND_ASSIGN( - auto field4, struct_type->FindField(type_manager, - StructType::FieldId("double_field"))); + auto field4, struct_type->FindFieldByName(type_manager, "double_field")); EXPECT_EQ(field4->name, "double_field"); EXPECT_EQ(field4->number, 3); EXPECT_EQ(field4->type, type_manager.type_factory().GetDoubleType()); - ASSERT_OK_AND_ASSIGN( - field4, struct_type->FindField(type_manager, StructType::FieldId(3))); + ASSERT_OK_AND_ASSIGN(field4, struct_type->FindFieldByNumber(type_manager, 3)); EXPECT_EQ(field4->name, "double_field"); EXPECT_EQ(field4->number, 3); EXPECT_EQ(field4->type, type_manager.type_factory().GetDoubleType()); - EXPECT_THAT(struct_type->FindField(type_manager, - StructType::FieldId("missing_field")), + EXPECT_THAT(struct_type->FindFieldByName(type_manager, "missing_field"), IsOkAndHolds(Eq(absl::nullopt))); - EXPECT_THAT(struct_type->FindField(type_manager, StructType::FieldId(4)), + EXPECT_THAT(struct_type->FindFieldByNumber(type_manager, 4), IsOkAndHolds(Eq(absl::nullopt))); } diff --git a/base/types/enum_type.cc b/base/types/enum_type.cc index 0e3a526d0..d04931fc3 100644 --- a/base/types/enum_type.cc +++ b/base/types/enum_type.cc @@ -14,16 +14,57 @@ #include "base/types/enum_type.h" +#include + #include "absl/base/macros.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/variant.h" +#include "internal/overloaded.h" #include "internal/status_macros.h" namespace cel { CEL_INTERNAL_TYPE_IMPL(EnumType); +bool operator<(const EnumType::ConstantId& lhs, + const EnumType::ConstantId& rhs) { + return absl::visit( + internal::Overloaded{ + [&rhs](absl::string_view lhs_name) { + return absl::visit( + internal::Overloaded{// (absl::string_view, absl::string_view) + [lhs_name](absl::string_view rhs_name) { + return lhs_name < rhs_name; + }, + // (absl::string_view, int64_t) + [](int64_t rhs_number) { return false; }}, + rhs.data_); + }, + [&rhs](int64_t lhs_number) { + return absl::visit( + internal::Overloaded{ + // (int64_t, absl::string_view) + [](absl::string_view rhs_name) { return true; }, + // (int64_t, int64_t) + [lhs_number](int64_t rhs_number) { + return lhs_number < rhs_number; + }, + }, + rhs.data_); + }}, + lhs.data_); +} + +std::string EnumType::ConstantId::DebugString() const { + return absl::visit( + internal::Overloaded{ + [](absl::string_view name) { return std::string(name); }, + [](int64_t number) { return absl::StrCat(number); }}, + data_); +} + EnumType::EnumType() : base_internal::HeapData(kKind) { // Ensure `Type*` and `base_internal::HeapData*` are not thunked. ABSL_ASSERT( diff --git a/base/types/enum_type.h b/base/types/enum_type.h index e74cba3e0..7d8e4c314 100644 --- a/base/types/enum_type.h +++ b/base/types/enum_type.h @@ -17,6 +17,7 @@ #include #include +#include #include #include "absl/base/attributes.h" @@ -46,21 +47,41 @@ class EnumType : public Type, public base_internal::HeapData { class ConstantId final { public: - explicit ConstantId(absl::string_view name) - : data_(absl::in_place_type, name) {} - - explicit ConstantId(int64_t number) - : data_(absl::in_place_type, number) {} - ConstantId() = delete; ConstantId(const ConstantId&) = default; + ConstantId(ConstantId&&) = default; ConstantId& operator=(const ConstantId&) = default; + ConstantId& operator=(ConstantId&&) = default; + + std::string DebugString() const; + + friend bool operator==(const ConstantId& lhs, const ConstantId& rhs) { + return lhs.data_ == rhs.data_; + } + + friend bool operator<(const ConstantId& lhs, const ConstantId& rhs); + + template + friend H AbslHashValue(H state, const ConstantId& id) { + return H::combine(std::move(state), id.data_); + } + + template + friend void AbslStringify(S& sink, const ConstantId& id) { + sink.Append(id.DebugString()); + } private: friend class EnumType; friend class EnumValue; + explicit ConstantId(absl::string_view name) + : data_(absl::in_place_type, name) {} + + explicit ConstantId(int64_t number) + : data_(absl::in_place_type, number) {} + absl::variant data_; }; @@ -106,6 +127,25 @@ class EnumType : public Type, public base_internal::HeapData { MemoryManager& memory_manager) const ABSL_ATTRIBUTE_LIFETIME_BOUND = 0; protected: + static ConstantId MakeConstantId(absl::string_view name) { + return ConstantId(name); + } + + static ConstantId MakeConstantId(int64_t number) { + return ConstantId(number); + } + + template + static std::enable_if_t< + std::conjunction_v< + std::is_enum, + std::is_convertible, int64_t>>, + ConstantId> + MakeConstantId(E e) { + return MakeConstantId( + static_cast(static_cast>(e))); + } + EnumType(); private: @@ -131,9 +171,13 @@ class EnumType : public Type, public base_internal::HeapData { // Constant describes a single value in an enumeration. All fields are valid so // long as EnumType is valid. struct EnumType::Constant final { - Constant(absl::string_view name, int64_t number, const void* hint = nullptr) - : name(name), number(number), hint(hint) {} + Constant(ConstantId id, absl::string_view name, int64_t number, + const void* hint = nullptr) + : id(id), name(name), number(number), hint(hint) {} + // Identifier which allows the most efficient form of lookup, compared to + // looking up by name or number. + ConstantId id; // The unqualified enumeration value name. absl::string_view name; // The enumeration value number. diff --git a/base/types/struct_type.cc b/base/types/struct_type.cc index 9afcef8d2..a06984d00 100644 --- a/base/types/struct_type.cc +++ b/base/types/struct_type.cc @@ -128,8 +128,8 @@ absl::StatusOr> StructType::FindField( absl::StatusOr StructType::FieldIterator::NextId( TypeManager& type_manager) { - CEL_ASSIGN_OR_RETURN(auto name, NextName(type_manager)); - return FieldId(name); + CEL_ASSIGN_OR_RETURN(auto field, Next(type_manager)); + return field.id; } absl::StatusOr StructType::FieldIterator::NextName( diff --git a/base/types/struct_type.h b/base/types/struct_type.h index 1b7b4dddb..d7ce4bca9 100644 --- a/base/types/struct_type.h +++ b/base/types/struct_type.h @@ -18,6 +18,7 @@ #include #include #include +#include #include #include "absl/base/attributes.h" @@ -51,16 +52,12 @@ class StructType : public Type { class FieldId final { public: - explicit FieldId(absl::string_view name) - : data_(absl::in_place_type, name) {} - - explicit FieldId(int64_t number) - : data_(absl::in_place_type, number) {} - FieldId() = delete; FieldId(const FieldId&) = default; + FieldId(FieldId&&) = default; FieldId& operator=(const FieldId&) = default; + FieldId& operator=(FieldId&&) = default; std::string DebugString() const; @@ -83,6 +80,13 @@ class StructType : public Type { private: friend class StructType; friend class StructValue; + friend struct base_internal::FieldIdFactory; + + explicit FieldId(absl::string_view name) + : data_(absl::in_place_type, name) {} + + explicit FieldId(int64_t number) + : data_(absl::in_place_type, number) {} absl::variant data_; }; @@ -129,6 +133,22 @@ class StructType : public Type { absl::StatusOr> NewFieldIterator( MemoryManager& memory_manager) const ABSL_ATTRIBUTE_LIFETIME_BOUND; + protected: + static FieldId MakeFieldId(absl::string_view name) { return FieldId(name); } + + static FieldId MakeFieldId(int64_t number) { return FieldId(number); } + + template + static std::enable_if_t< + std::conjunction_v< + std::is_enum, + std::is_convertible, int64_t>>, + FieldId> + MakeFieldId(E e) { + return MakeFieldId( + static_cast(static_cast>(e))); + } + private: friend internal::TypeInfo base_internal::GetStructTypeTypeId( const StructType& struct_type); @@ -151,10 +171,13 @@ class StructType : public Type { // Field describes a single field in a struct. All fields are valid so long as // StructType is valid, except Field::type which is managed. struct StructType::Field final { - explicit Field(absl::string_view name, int64_t number, Handle type, - const void* hint = nullptr) - : name(name), number(number), type(std::move(type)), hint(hint) {} + explicit Field(FieldId id, absl::string_view name, int64_t number, + Handle type, const void* hint = nullptr) + : id(id), name(name), number(number), type(std::move(type)), hint(hint) {} + // Identifier which allows the most efficient form of lookup, compared to + // looking up by name or number. + FieldId id; // The field name. absl::string_view name; // The field number. @@ -196,11 +219,12 @@ namespace base_internal { ABSL_ATTRIBUTE_WEAK absl::string_view MessageTypeName(uintptr_t msg); ABSL_ATTRIBUTE_WEAK size_t MessageTypeFieldCount(uintptr_t msg); -class LegacyStructType final : public StructType, - public base_internal::InlineData { +class LegacyStructValueFieldIterator; + +class LegacyStructType final : public StructType, public InlineData { public: static bool Is(const Type& type) { - return type.kind() == kKind && + return StructType::Is(type) && static_cast(type).TypeId() == internal::TypeId(); } @@ -234,17 +258,17 @@ class LegacyStructType final : public StructType, private: static constexpr uintptr_t kMetadata = - base_internal::kStoredInline | base_internal::kTrivial | - (static_cast(kKind) << base_internal::kKindShift); + kStoredInline | kTrivial | (static_cast(kKind) << kKindShift); + friend class LegacyStructValueFieldIterator; friend struct interop_internal::LegacyStructTypeAccess; friend class cel::StructType; - friend class base_internal::LegacyStructValue; + friend class LegacyStructValue; template friend class AnyData; explicit LegacyStructType(uintptr_t msg) - : StructType(), base_internal::InlineData(kMetadata), msg_(msg) {} + : StructType(), InlineData(kMetadata), msg_(msg) {} internal::TypeInfo TypeId() const { return internal::TypeId(); @@ -255,10 +279,10 @@ class LegacyStructType final : public StructType, uintptr_t msg_; }; -class AbstractStructType : public StructType, public base_internal::HeapData { +class AbstractStructType : public StructType, public HeapData { public: static bool Is(const Type& type) { - return type.kind() == kKind && + return StructType::Is(type) && static_cast(type).TypeId() != internal::TypeId(); } @@ -293,14 +317,13 @@ class AbstractStructType : public StructType, public base_internal::HeapData { AbstractStructType(); private: - friend internal::TypeInfo base_internal::GetStructTypeTypeId( - const StructType& struct_type); + friend internal::TypeInfo GetStructTypeTypeId(const StructType& struct_type); struct FindFieldVisitor; friend struct FindFieldVisitor; friend class MemoryManager; friend class TypeFactory; - friend class base_internal::TypeHandle; + friend class TypeHandle; friend class StructValue; friend class cel::StructType; @@ -343,6 +366,17 @@ CEL_INTERNAL_TYPE_DECL(StructType); namespace base_internal { +// This should be used for testing only, and is private. +struct FieldIdFactory { + static StructType::FieldId Make(absl::string_view name) { + return StructType::FieldId(name); + } + + static StructType::FieldId Make(int64_t number) { + return StructType::FieldId(number); + } +}; + inline internal::TypeInfo GetStructTypeTypeId(const StructType& struct_type) { return struct_type.TypeId(); } diff --git a/base/value_factory.h b/base/value_factory.h index 5bcc170a2..2e02a0d21 100644 --- a/base/value_factory.h +++ b/base/value_factory.h @@ -21,6 +21,7 @@ #include #include "absl/base/attributes.h" +#include "absl/base/optimization.h" #include "absl/log/die_if_null.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -247,8 +248,8 @@ class ValueFactory final { const Handle& enum_type, int64_t number) ABSL_ATTRIBUTE_LIFETIME_BOUND { CEL_ASSIGN_OR_RETURN(auto constant, - enum_type->FindConstant(EnumType::ConstantId(number))); - if (!constant.has_value()) { + enum_type->FindConstantByNumber(number)); + if (ABSL_PREDICT_FALSE(!constant.has_value())) { return absl::NotFoundError(absl::StrCat("no such enum number", number)); } return base_internal::HandleFactory::template Make( @@ -258,9 +259,8 @@ class ValueFactory final { absl::StatusOr> CreateEnumValue( const Handle& enum_type, absl::string_view name) ABSL_ATTRIBUTE_LIFETIME_BOUND { - CEL_ASSIGN_OR_RETURN(auto constant, - enum_type->FindConstant(EnumType::ConstantId(name))); - if (!constant.has_value()) { + CEL_ASSIGN_OR_RETURN(auto constant, enum_type->FindConstantByName(name)); + if (ABSL_PREDICT_FALSE(!constant.has_value())) { return absl::NotFoundError(absl::StrCat("no such enum value", name)); } return base_internal::HandleFactory::template Make( diff --git a/base/value_test.cc b/base/value_test.cc index 2c37b7e38..24ae2f316 100644 --- a/base/value_test.cc +++ b/base/value_test.cc @@ -73,10 +73,10 @@ class TestEnumType final : public EnumType { absl::StatusOr> FindConstantByName( absl::string_view name) const override { if (name == "VALUE1") { - return Constant("VALUE1", 1); + return Constant(MakeConstantId(1), "VALUE1", 1); } if (name == "VALUE2") { - return Constant("VALUE2", 2); + return Constant(MakeConstantId(2), "VALUE2", 2); } return absl::nullopt; } @@ -85,9 +85,9 @@ class TestEnumType final : public EnumType { int64_t number) const override { switch (number) { case 1: - return Constant("VALUE1", 1); + return Constant(MakeConstantId(1), "VALUE1", 1); case 2: - return Constant("VALUE2", 2); + return Constant(MakeConstantId(2), "VALUE2", 2); default: return absl::nullopt; } @@ -228,13 +228,16 @@ class TestStructType final : public CEL_STRUCT_TYPE_CLASS { absl::StatusOr> FindFieldByName( TypeManager& type_manager, absl::string_view name) const override { if (name == "bool_field") { - return Field("bool_field", 0, type_manager.type_factory().GetBoolType()); + return Field(MakeFieldId(0), "bool_field", 0, + type_manager.type_factory().GetBoolType()); } else if (name == "int_field") { - return Field("int_field", 1, type_manager.type_factory().GetIntType()); + return Field(MakeFieldId(1), "int_field", 1, + type_manager.type_factory().GetIntType()); } else if (name == "uint_field") { - return Field("uint_field", 2, type_manager.type_factory().GetUintType()); + return Field(MakeFieldId(2), "uint_field", 2, + type_manager.type_factory().GetUintType()); } else if (name == "double_field") { - return Field("double_field", 3, + return Field(MakeFieldId(3), "double_field", 3, type_manager.type_factory().GetDoubleType()); } return absl::nullopt; @@ -244,15 +247,16 @@ class TestStructType final : public CEL_STRUCT_TYPE_CLASS { TypeManager& type_manager, int64_t number) const override { switch (number) { case 0: - return Field("bool_field", 0, + return Field(MakeFieldId(0), "bool_field", 0, type_manager.type_factory().GetBoolType()); case 1: - return Field("int_field", 1, type_manager.type_factory().GetIntType()); + return Field(MakeFieldId(1), "int_field", 1, + type_manager.type_factory().GetIntType()); case 2: - return Field("uint_field", 2, + return Field(MakeFieldId(2), "uint_field", 2, type_manager.type_factory().GetUintType()); case 3: - return Field("double_field", 3, + return Field(MakeFieldId(3), "double_field", 3, type_manager.type_factory().GetDoubleType()); default: return absl::nullopt; @@ -2104,33 +2108,27 @@ TEST_P(StructValueTest, GetField) { auto struct_value, value_factory.CreateStructValue(struct_type)); StructValue::GetFieldContext context(value_factory); - EXPECT_THAT( - struct_value->GetField(context, StructValue::FieldId("bool_field")), - IsOkAndHolds(Eq(value_factory.CreateBoolValue(false)))); - EXPECT_THAT(struct_value->GetField(context, StructValue::FieldId(0)), + EXPECT_THAT(struct_value->GetFieldByName(context, "bool_field"), IsOkAndHolds(Eq(value_factory.CreateBoolValue(false)))); - EXPECT_THAT( - struct_value->GetField(context, StructValue::FieldId("int_field")), - IsOkAndHolds(Eq(value_factory.CreateIntValue(0)))); - EXPECT_THAT(struct_value->GetField(context, StructValue::FieldId(1)), + EXPECT_THAT(struct_value->GetFieldByNumber(context, 0), + IsOkAndHolds(Eq(value_factory.CreateBoolValue(false)))); + EXPECT_THAT(struct_value->GetFieldByName(context, "int_field"), + IsOkAndHolds(Eq(value_factory.CreateIntValue(0)))); + EXPECT_THAT(struct_value->GetFieldByNumber(context, 1), IsOkAndHolds(Eq(value_factory.CreateIntValue(0)))); - EXPECT_THAT( - struct_value->GetField(context, StructValue::FieldId("uint_field")), - IsOkAndHolds(Eq(value_factory.CreateUintValue(0)))); - EXPECT_THAT(struct_value->GetField(context, StructValue::FieldId(2)), + EXPECT_THAT(struct_value->GetFieldByName(context, "uint_field"), IsOkAndHolds(Eq(value_factory.CreateUintValue(0)))); - EXPECT_THAT( - struct_value->GetField(context, StructValue::FieldId("double_field")), - IsOkAndHolds(Eq(value_factory.CreateDoubleValue(0.0)))); - EXPECT_THAT(struct_value->GetField(context, StructValue::FieldId(3)), + EXPECT_THAT(struct_value->GetFieldByNumber(context, 2), + IsOkAndHolds(Eq(value_factory.CreateUintValue(0)))); + EXPECT_THAT(struct_value->GetFieldByName(context, "double_field"), + IsOkAndHolds(Eq(value_factory.CreateDoubleValue(0.0)))); + EXPECT_THAT(struct_value->GetFieldByNumber(context, 3), IsOkAndHolds(Eq(value_factory.CreateDoubleValue(0.0)))); - EXPECT_THAT( - struct_value->GetField(context, StructValue::FieldId("missing_field")), - StatusIs(absl::StatusCode::kNotFound)); - EXPECT_THAT( - struct_value->HasField(StructValue::HasFieldContext((type_manager)), - StructValue::FieldId(4)), - StatusIs(absl::StatusCode::kNotFound)); + EXPECT_THAT(struct_value->GetFieldByName(context, "missing_field"), + StatusIs(absl::StatusCode::kNotFound)); + EXPECT_THAT(struct_value->HasFieldByNumber( + StructValue::HasFieldContext((type_manager)), 4), + StatusIs(absl::StatusCode::kNotFound)); } TEST_P(StructValueTest, HasField) { @@ -2143,30 +2141,21 @@ TEST_P(StructValueTest, HasField) { auto struct_value, value_factory.CreateStructValue(struct_type)); StructValue::HasFieldContext context(type_manager); - EXPECT_THAT( - struct_value->HasField(context, StructValue::FieldId("bool_field")), - IsOkAndHolds(true)); - EXPECT_THAT(struct_value->HasField(context, StructValue::FieldId(0)), + EXPECT_THAT(struct_value->HasFieldByName(context, "bool_field"), IsOkAndHolds(true)); - EXPECT_THAT( - struct_value->HasField(context, StructValue::FieldId("int_field")), - IsOkAndHolds(true)); - EXPECT_THAT(struct_value->HasField(context, StructValue::FieldId(1)), + EXPECT_THAT(struct_value->HasFieldByNumber(context, 0), IsOkAndHolds(true)); + EXPECT_THAT(struct_value->HasFieldByName(context, "int_field"), IsOkAndHolds(true)); - EXPECT_THAT( - struct_value->HasField(context, StructValue::FieldId("uint_field")), - IsOkAndHolds(true)); - EXPECT_THAT(struct_value->HasField(context, StructValue::FieldId(2)), + EXPECT_THAT(struct_value->HasFieldByNumber(context, 1), IsOkAndHolds(true)); + EXPECT_THAT(struct_value->HasFieldByName(context, "uint_field"), IsOkAndHolds(true)); - EXPECT_THAT( - struct_value->HasField(context, StructValue::FieldId("double_field")), - IsOkAndHolds(true)); - EXPECT_THAT(struct_value->HasField(context, StructValue::FieldId(3)), + EXPECT_THAT(struct_value->HasFieldByNumber(context, 2), IsOkAndHolds(true)); + EXPECT_THAT(struct_value->HasFieldByName(context, "double_field"), IsOkAndHolds(true)); - EXPECT_THAT( - struct_value->HasField(context, StructValue::FieldId("missing_field")), - StatusIs(absl::StatusCode::kNotFound)); - EXPECT_THAT(struct_value->HasField(context, StructValue::FieldId(4)), + EXPECT_THAT(struct_value->HasFieldByNumber(context, 3), IsOkAndHolds(true)); + EXPECT_THAT(struct_value->HasFieldByName(context, "missing_field"), + StatusIs(absl::StatusCode::kNotFound)); + EXPECT_THAT(struct_value->HasFieldByNumber(context, 4), StatusIs(absl::StatusCode::kNotFound)); } @@ -2475,8 +2464,6 @@ TEST(EnumValueTest, UnknownConstantDebugString) { ASSERT_OK_AND_ASSIGN(auto enum_type, type_factory.CreateEnumType()); EXPECT_EQ(EnumValue::DebugString(*enum_type, 3), "test_enum.TestEnum(3)"); - EXPECT_EQ(EnumValue::DebugString(*enum_type, EnumType::Constant("", 3)), - "test_enum.TestEnum(3)"); } Handle DefaultNullValue(ValueFactory& value_factory) { diff --git a/base/values/struct_value.cc b/base/values/struct_value.cc index a41693004..7c9a5dbfb 100644 --- a/base/values/struct_value.cc +++ b/base/values/struct_value.cc @@ -136,8 +136,6 @@ absl::StatusOr> StructValue::FieldIterator::NextValue( namespace base_internal { -namespace { - class LegacyStructValueFieldIterator final : public StructValue::FieldIterator { public: LegacyStructValueFieldIterator(uintptr_t msg, uintptr_t type_info) @@ -160,7 +158,7 @@ class LegacyStructValueFieldIterator final : public StructValue::FieldIterator { msg_, type_info_, context.value_factory(), field_name, context.unbox_null_wrapper_types())); ++index_; - return Field(StructValue::FieldId(field_name), std::move(value)); + return Field(LegacyStructType::MakeFieldId(field_name), std::move(value)); } absl::StatusOr NextId( @@ -170,7 +168,7 @@ class LegacyStructValueFieldIterator final : public StructValue::FieldIterator { "StructValue::FieldIterator::Next() called when " "StructValue::FieldIterator::HasNext() returns false"); } - return StructValue::FieldId(field_names_[index_++]); + return LegacyStructType::MakeFieldId(field_names_[index_++]); } private: @@ -180,8 +178,6 @@ class LegacyStructValueFieldIterator final : public StructValue::FieldIterator { size_t index_ = 0; }; -} // namespace - Handle LegacyStructValue::type() const { if ((msg_ & kMessageWrapperTagMask) == kMessageWrapperTagMask) { // google::protobuf::Message diff --git a/base/values/struct_value.h b/base/values/struct_value.h index 31f56f9af..af3ba91bb 100644 --- a/base/values/struct_value.h +++ b/base/values/struct_value.h @@ -130,6 +130,25 @@ class StructValue : public Value { Handle value; }; + protected: + static FieldId MakeFieldId(absl::string_view name) { + return StructType::MakeFieldId(name); + } + + static FieldId MakeFieldId(int64_t number) { + return StructType::MakeFieldId(number); + } + + template + static std::enable_if_t< + std::conjunction_v< + std::is_enum, + std::is_convertible, int64_t>>, + FieldId> + MakeFieldId(E e) { + return StructType::MakeFieldId(e); + } + private: struct GetFieldVisitor; struct HasFieldVisitor; diff --git a/eval/eval/select_step.cc b/eval/eval/select_step.cc index 3d835f0eb..08635e8a4 100644 --- a/eval/eval/select_step.cc +++ b/eval/eval/select_step.cc @@ -88,11 +88,11 @@ class SelectStep : public ExpressionStepBase { absl::StatusOr> SelectStep::CreateValueFromField( const Handle& msg, ExecutionFrame* frame) const { - return msg->GetField( + return msg->GetFieldByName( StructValue::GetFieldContext(frame->value_factory()) .set_unbox_null_wrapper_types(unboxing_option_ == ProtoWrapperTypeOptions::kUnsetNull), - StructValue::FieldId(field_)); + field_); } absl::optional> CheckForMarkedAttributes( @@ -130,11 +130,10 @@ Handle TestOnlySelect(const Handle& msg, const std::string& field, cel::MemoryManager& memory_manager, cel::TypeManager& type_manager) { - StructValue::FieldId field_id(field); Arena* arena = ProtoMemoryManager::CastToProtoArena(memory_manager); absl::StatusOr result = - msg->HasField(StructValue::HasFieldContext(type_manager), field_id); + msg->HasFieldByName(StructValue::HasFieldContext(type_manager), field); if (!result.ok()) { return CreateErrorValueFromView( diff --git a/eval/internal/interop_test.cc b/eval/internal/interop_test.cc index 3f65416c3..f7747c8c5 100644 --- a/eval/internal/interop_test.cc +++ b/eval/internal/interop_test.cc @@ -26,6 +26,7 @@ #include "absl/strings/escaping.h" #include "absl/time/time.h" #include "base/memory.h" +#include "base/type.h" #include "base/type_manager.h" #include "base/value.h" #include "base/value_factory.h" @@ -737,24 +738,21 @@ TEST(ValueInterop, StructFromLegacy) { EXPECT_EQ(value->kind(), Kind::kStruct); EXPECT_EQ(value->type()->kind(), Kind::kStruct); EXPECT_EQ(value->type()->name(), "google.protobuf.Api"); - EXPECT_THAT(value.As()->HasField( - StructValue::HasFieldContext(type_manager), - StructValue::FieldId("name")), + EXPECT_THAT(value.As()->HasFieldByName( + StructValue::HasFieldContext(type_manager), "name"), IsOkAndHolds(Eq(true))); - EXPECT_THAT( - value.As()->HasField( - StructValue::HasFieldContext(type_manager), StructValue::FieldId(1)), - StatusIs(absl::StatusCode::kUnimplemented)); - ASSERT_OK_AND_ASSIGN(auto value_name_field, - value.As()->GetField( - StructValue::GetFieldContext(value_factory), - StructValue::FieldId("name"))); + EXPECT_THAT(value.As()->HasFieldByNumber( + StructValue::HasFieldContext(type_manager), 1), + StatusIs(absl::StatusCode::kUnimplemented)); + ASSERT_OK_AND_ASSIGN( + auto value_name_field, + value.As()->GetFieldByName( + StructValue::GetFieldContext(value_factory), "name")); ASSERT_TRUE(value_name_field->Is()); EXPECT_EQ(value_name_field.As()->ToString(), "foo"); - EXPECT_THAT( - value.As()->GetField( - StructValue::GetFieldContext(value_factory), StructValue::FieldId(1)), - StatusIs(absl::StatusCode::kUnimplemented)); + EXPECT_THAT(value.As()->GetFieldByNumber( + StructValue::GetFieldContext(value_factory), 1), + StatusIs(absl::StatusCode::kUnimplemented)); auto value_wrapper = LegacyStructValueAccess::ToMessageWrapper( *value.As()); auto legacy_value_wrapper = legacy_value.MessageWrapperOrDie(); @@ -830,6 +828,8 @@ TEST(ValueInterop, LegacyStructEquality) { EXPECT_EQ(lhs_value, rhs_value); } +using ::cel::base_internal::FieldIdFactory; + TEST(ValueInterop, LegacyStructNewFieldIteratorIds) { google::protobuf::Arena arena; extensions::ProtoMemoryManager memory_manager(&arena); @@ -853,8 +853,8 @@ TEST(ValueInterop, LegacyStructNewFieldIteratorIds) { } EXPECT_THAT(iterator->NextId(StructValue::GetFieldContext(value_factory)), StatusIs(absl::StatusCode::kFailedPrecondition)); - std::set expected_ids = {StructType::FieldId("name"), - StructType::FieldId("version")}; + std::set expected_ids = { + FieldIdFactory::Make("name"), FieldIdFactory::Make("version")}; EXPECT_EQ(actual_ids, expected_ids); } diff --git a/extensions/protobuf/enum_type.cc b/extensions/protobuf/enum_type.cc index 56de20399..8d801418e 100644 --- a/extensions/protobuf/enum_type.cc +++ b/extensions/protobuf/enum_type.cc @@ -25,8 +25,6 @@ namespace cel::extensions { -namespace { - class ProtoEnumTypeConstantIterator final : public EnumType::ConstantIterator { public: explicit ProtoEnumTypeConstantIterator( @@ -42,7 +40,8 @@ class ProtoEnumTypeConstantIterator final : public EnumType::ConstantIterator { "EnumType::ConstantIterator::HasNext() returns false"); } const auto* value = descriptor_.value(index_++); - return Constant(value->name(), value->number(), value); + return Constant(ProtoEnumType::MakeConstantId(value->number()), + value->name(), value->number(), value); } private: @@ -50,8 +49,6 @@ class ProtoEnumTypeConstantIterator final : public EnumType::ConstantIterator { int index_ = 0; }; -} // namespace - absl::StatusOr> ProtoEnumType::Resolve( TypeManager& type_manager, const google::protobuf::EnumDescriptor& descriptor) { CEL_ASSIGN_OR_RETURN(auto type, @@ -80,7 +77,8 @@ ProtoEnumType::FindConstantByName(absl::string_view name) const { return absl::nullopt; } ABSL_ASSERT(value_desc->name() == name); - return Constant{value_desc->name(), value_desc->number(), value_desc}; + return Constant(MakeConstantId(value_desc->number()), value_desc->name(), + value_desc->number(), value_desc); } absl::StatusOr> @@ -96,7 +94,8 @@ ProtoEnumType::FindConstantByNumber(int64_t number) const { return absl::nullopt; } ABSL_ASSERT(value_desc->number() == number); - return Constant{value_desc->name(), value_desc->number(), value_desc}; + return Constant(MakeConstantId(value_desc->number()), value_desc->name(), + value_desc->number(), value_desc); } absl::StatusOr> diff --git a/extensions/protobuf/enum_type.h b/extensions/protobuf/enum_type.h index 47abb2781..b4772e1fd 100644 --- a/extensions/protobuf/enum_type.h +++ b/extensions/protobuf/enum_type.h @@ -29,6 +29,8 @@ namespace cel::extensions { class ProtoType; class ProtoTypeProvider; +class ProtoEnumTypeConstantIterator; + class ProtoEnumType final : public EnumType { public: static bool Is(const Type& type) { @@ -62,6 +64,7 @@ class ProtoEnumType final : public EnumType { const google::protobuf::EnumDescriptor& descriptor() const { return *descriptor_; } private: + friend class ProtoEnumTypeConstantIterator; friend class ProtoType; friend class ProtoTypeProvider; friend class cel::MemoryManager; diff --git a/extensions/protobuf/enum_type_test.cc b/extensions/protobuf/enum_type_test.cc index eee98dcc4..82edee0ac 100644 --- a/extensions/protobuf/enum_type_test.cc +++ b/extensions/protobuf/enum_type_test.cc @@ -75,8 +75,7 @@ TEST_P(ProtoEnumTypeTest, FindConstantByName) { ASSERT_OK_AND_ASSIGN( auto type, ProtoType::Resolve(type_manager)); - ASSERT_OK_AND_ASSIGN(auto constant, - type->FindConstant(EnumType::ConstantId("TYPE_STRING"))); + ASSERT_OK_AND_ASSIGN(auto constant, type->FindConstantByName("TYPE_STRING")); ASSERT_TRUE(constant.has_value()); EXPECT_EQ(constant->number, 9); EXPECT_EQ(constant->name, "TYPE_STRING"); @@ -89,8 +88,7 @@ TEST_P(ProtoEnumTypeTest, FindConstantByNumber) { ASSERT_OK_AND_ASSIGN( auto type, ProtoType::Resolve(type_manager)); - ASSERT_OK_AND_ASSIGN(auto constant, - type->FindConstant(EnumType::ConstantId(9))); + ASSERT_OK_AND_ASSIGN(auto constant, type->FindConstantByNumber(9)); ASSERT_TRUE(constant.has_value()); EXPECT_EQ(constant->number, 9); EXPECT_EQ(constant->name, "TYPE_STRING"); diff --git a/extensions/protobuf/struct_type.cc b/extensions/protobuf/struct_type.cc index 8adc6c12a..cc4a1f808 100644 --- a/extensions/protobuf/struct_type.cc +++ b/extensions/protobuf/struct_type.cc @@ -123,8 +123,6 @@ absl::StatusOr> FieldDescriptorToType( } // namespace -namespace { - class ProtoStructTypeFieldIterator final : public StructType::FieldIterator { public: explicit ProtoStructTypeFieldIterator(const google::protobuf::Descriptor& descriptor) @@ -141,10 +139,20 @@ class ProtoStructTypeFieldIterator final : public StructType::FieldIterator { const auto* field = descriptor_.field(index_); CEL_ASSIGN_OR_RETURN(auto type, FieldDescriptorToType(type_manager, field)); ++index_; - return StructType::Field(field->name(), field->number(), std::move(type), + return StructType::Field(ProtoStructType::MakeFieldId(field->number()), + field->name(), field->number(), std::move(type), field); } + absl::StatusOr NextId(TypeManager& type_manager) override { + if (ABSL_PREDICT_FALSE(index_ >= descriptor_.field_count())) { + return absl::FailedPreconditionError( + "StructType::FieldIterator::Next() called when " + "StructType::FieldIterator::HasNext() returns false"); + } + return ProtoStructType::MakeFieldId(descriptor_.field(index_++)->number()); + } + absl::StatusOr NextName( TypeManager& type_manager) override { if (ABSL_PREDICT_FALSE(index_ >= descriptor_.field_count())) { @@ -169,8 +177,6 @@ class ProtoStructTypeFieldIterator final : public StructType::FieldIterator { int index_ = 0; }; -} // namespace - size_t ProtoStructType::field_count() const { return descriptor().field_count(); } @@ -189,8 +195,8 @@ ProtoStructType::FindFieldByName(TypeManager& type_manager, } CEL_ASSIGN_OR_RETURN(auto type, FieldDescriptorToType(type_manager, field_desc)); - return Field{field_desc->name(), field_desc->number(), std::move(type), - field_desc}; + return Field(MakeFieldId(field_desc->number()), field_desc->name(), + field_desc->number(), std::move(type), field_desc); } absl::StatusOr> @@ -208,8 +214,8 @@ ProtoStructType::FindFieldByNumber(TypeManager& type_manager, } CEL_ASSIGN_OR_RETURN(auto type, FieldDescriptorToType(type_manager, field_desc)); - return Field{field_desc->name(), field_desc->number(), std::move(type), - field_desc}; + return Field(MakeFieldId(field_desc->number()), field_desc->name(), + field_desc->number(), std::move(type), field_desc); } } // namespace cel::extensions diff --git a/extensions/protobuf/struct_type.h b/extensions/protobuf/struct_type.h index e4a95a41e..852e77836 100644 --- a/extensions/protobuf/struct_type.h +++ b/extensions/protobuf/struct_type.h @@ -37,6 +37,8 @@ namespace protobuf_internal { class ParsedProtoStructValue; } +class ProtoStructTypeFieldIterator; + class ProtoStructType final : public CEL_STRUCT_TYPE_CLASS { public: static bool Is(const Type& type) { @@ -71,6 +73,7 @@ class ProtoStructType final : public CEL_STRUCT_TYPE_CLASS { const google::protobuf::Descriptor& descriptor() const { return *descriptor_; } private: + friend class ProtoStructTypeFieldIterator; friend class ProtoType; friend class ProtoTypeProvider; friend class ProtoStructValue; diff --git a/extensions/protobuf/struct_type_test.cc b/extensions/protobuf/struct_type_test.cc index 27f82c4dd..4499b656d 100644 --- a/extensions/protobuf/struct_type_test.cc +++ b/extensions/protobuf/struct_type_test.cc @@ -75,9 +75,8 @@ TEST_P(ProtoStructTypeTest, FindFieldByName) { TypeManager type_manager(type_factory, type_provider); ASSERT_OK_AND_ASSIGN( auto type, ProtoType::Resolve(type_manager)); - ASSERT_OK_AND_ASSIGN( - auto field, - type->FindField(type_manager, StructType::FieldId("default_value"))); + ASSERT_OK_AND_ASSIGN(auto field, + type->FindFieldByName(type_manager, "default_value")); ASSERT_TRUE(field.has_value()); EXPECT_EQ(field->number, 11); EXPECT_EQ(field->name, "default_value"); @@ -90,8 +89,7 @@ TEST_P(ProtoStructTypeTest, FindFieldByNumber) { TypeManager type_manager(type_factory, type_provider); ASSERT_OK_AND_ASSIGN( auto type, ProtoType::Resolve(type_manager)); - ASSERT_OK_AND_ASSIGN(auto field, - type->FindField(type_manager, StructType::FieldId(11))); + ASSERT_OK_AND_ASSIGN(auto field, type->FindFieldByNumber(type_manager, 11)); ASSERT_TRUE(field.has_value()); EXPECT_EQ(field->number, 11); EXPECT_EQ(field->name, "default_value"); @@ -104,9 +102,8 @@ TEST_P(ProtoStructTypeTest, EnumField) { TypeManager type_manager(type_factory, type_provider); ASSERT_OK_AND_ASSIGN( auto type, ProtoType::Resolve(type_manager)); - ASSERT_OK_AND_ASSIGN( - auto field, - type->FindField(type_manager, StructType::FieldId("cardinality"))); + ASSERT_OK_AND_ASSIGN(auto field, + type->FindFieldByName(type_manager, "cardinality")); ASSERT_TRUE(field.has_value()); EXPECT_TRUE(field->type->Is()); EXPECT_EQ(field->type->name(), "google.protobuf.Field.Cardinality"); @@ -118,8 +115,8 @@ TEST_P(ProtoStructTypeTest, BoolField) { TypeManager type_manager(type_factory, type_provider); ASSERT_OK_AND_ASSIGN( auto type, ProtoType::Resolve(type_manager)); - ASSERT_OK_AND_ASSIGN( - auto field, type->FindField(type_manager, StructType::FieldId("packed"))); + ASSERT_OK_AND_ASSIGN(auto field, + type->FindFieldByName(type_manager, "packed")); ASSERT_TRUE(field.has_value()); EXPECT_EQ(field->type, type_factory.GetBoolType()); } @@ -130,9 +127,8 @@ TEST_P(ProtoStructTypeTest, IntField) { TypeManager type_manager(type_factory, type_provider); ASSERT_OK_AND_ASSIGN( auto type, ProtoType::Resolve(type_manager)); - ASSERT_OK_AND_ASSIGN( - auto field, - type->FindField(type_manager, StructType::FieldId("oneof_index"))); + ASSERT_OK_AND_ASSIGN(auto field, + type->FindFieldByName(type_manager, "oneof_index")); ASSERT_TRUE(field.has_value()); EXPECT_EQ(field->type, type_factory.GetIntType()); } @@ -143,8 +139,8 @@ TEST_P(ProtoStructTypeTest, StringListField) { TypeManager type_manager(type_factory, type_provider); ASSERT_OK_AND_ASSIGN( auto type, ProtoType::Resolve(type_manager)); - ASSERT_OK_AND_ASSIGN( - auto field, type->FindField(type_manager, StructType::FieldId("oneofs"))); + ASSERT_OK_AND_ASSIGN(auto field, + type->FindFieldByName(type_manager, "oneofs")); ASSERT_TRUE(field.has_value()); EXPECT_TRUE(field->type->Is()); EXPECT_EQ(field->type.As()->element(), @@ -157,9 +153,8 @@ TEST_P(ProtoStructTypeTest, StructListField) { TypeManager type_manager(type_factory, type_provider); ASSERT_OK_AND_ASSIGN( auto type, ProtoType::Resolve(type_manager)); - ASSERT_OK_AND_ASSIGN( - auto field, - type->FindField(type_manager, StructType::FieldId("options"))); + ASSERT_OK_AND_ASSIGN(auto field, + type->FindFieldByName(type_manager, "options")); ASSERT_TRUE(field.has_value()); EXPECT_TRUE(field->type->Is()); EXPECT_EQ(field->type.As()->element()->name(), @@ -173,14 +168,15 @@ TEST_P(ProtoStructTypeTest, MapField) { ASSERT_OK_AND_ASSIGN(auto type, ProtoType::Resolve(type_manager)); ASSERT_OK_AND_ASSIGN( - auto field, - type->FindField(type_manager, StructType::FieldId("map_string_string"))); + auto field, type->FindFieldByName(type_manager, "map_string_string")); ASSERT_TRUE(field.has_value()); EXPECT_TRUE(field->type->Is()); EXPECT_EQ(field->type.As()->key(), type_factory.GetStringType()); EXPECT_EQ(field->type.As()->value(), type_factory.GetStringType()); } +using ::cel::base_internal::FieldIdFactory; + TEST_P(ProtoStructTypeTest, NewFieldIteratorIds) { TypeFactory type_factory(memory_manager()); ProtoTypeProvider type_provider; @@ -198,7 +194,8 @@ TEST_P(ProtoStructTypeTest, NewFieldIteratorIds) { std::set expected_ids; const auto* const descriptor = TestAllTypes::descriptor(); for (int index = 0; index < descriptor->field_count(); ++index) { - expected_ids.insert(StructType::FieldId(descriptor->field(index)->name())); + expected_ids.insert( + FieldIdFactory::Make(descriptor->field(index)->number())); } EXPECT_EQ(actual_ids, expected_ids); } diff --git a/extensions/protobuf/struct_value.cc b/extensions/protobuf/struct_value.cc index e09590bfa..8ead55cc2 100644 --- a/extensions/protobuf/struct_value.cc +++ b/extensions/protobuf/struct_value.cc @@ -2211,7 +2211,7 @@ absl::StatusOr> ParsedProtoStructValue::GetFieldByName( const GetFieldContext& context, absl::string_view name) const { CEL_ASSIGN_OR_RETURN( auto field_type, - type()->FindField(context.value_factory().type_manager(), FieldId(name))); + type()->FindFieldByName(context.value_factory().type_manager(), name)); if (ABSL_PREDICT_FALSE(!field_type)) { return interop_internal::CreateNoSuchFieldError(name); } @@ -2221,8 +2221,8 @@ absl::StatusOr> ParsedProtoStructValue::GetFieldByName( absl::StatusOr> ParsedProtoStructValue::GetFieldByNumber( const GetFieldContext& context, int64_t number) const { CEL_ASSIGN_OR_RETURN(auto field_type, - type()->FindField(context.value_factory().type_manager(), - FieldId(number))); + type()->FindFieldByNumber( + context.value_factory().type_manager(), number)); if (ABSL_PREDICT_FALSE(!field_type)) { return interop_internal::CreateNoSuchFieldError(absl::StrCat(number)); } @@ -2619,8 +2619,8 @@ absl::StatusOr> ParsedProtoStructValue::GetSingularField( absl::StatusOr ParsedProtoStructValue::HasFieldByName( const HasFieldContext& context, absl::string_view name) const { - CEL_ASSIGN_OR_RETURN( - auto field, type()->FindField(context.type_manager(), FieldId(name))); + CEL_ASSIGN_OR_RETURN(auto field, + type()->FindFieldByName(context.type_manager(), name)); if (ABSL_PREDICT_FALSE(!field.has_value())) { return interop_internal::CreateNoSuchFieldError(name); } @@ -2630,7 +2630,7 @@ absl::StatusOr ParsedProtoStructValue::HasFieldByName( absl::StatusOr ParsedProtoStructValue::HasFieldByNumber( const HasFieldContext& context, int64_t number) const { CEL_ASSIGN_OR_RETURN( - auto field, type()->FindField(context.type_manager(), FieldId(number))); + auto field, type()->FindFieldByNumber(context.type_manager(), number)); if (ABSL_PREDICT_FALSE(!field.has_value())) { return interop_internal::CreateNoSuchFieldError(absl::StrCat(number)); } @@ -2675,7 +2675,8 @@ class ParsedProtoStructValueFieldIterator final CEL_ASSIGN_OR_RETURN(auto value, value_->GetField(context, std::move(type).value())); ++index_; - return Field(StructValue::FieldId(field->name()), std::move(value)); + return Field(ParsedProtoStructValue::MakeFieldId(field->number()), + std::move(value)); } absl::StatusOr NextId( @@ -2685,7 +2686,7 @@ class ParsedProtoStructValueFieldIterator final "StructValue::FieldIterator::Next() called when " "StructValue::FieldIterator::HasNext() returns false"); } - return StructValue::FieldId(fields_[index_++]->name()); + return ParsedProtoStructValue::MakeFieldId(fields_[index_++]->number()); } private: diff --git a/extensions/protobuf/struct_value_test.cc b/extensions/protobuf/struct_value_test.cc index 951873fe3..c7cdff05b 100644 --- a/extensions/protobuf/struct_value_test.cc +++ b/extensions/protobuf/struct_value_test.cc @@ -93,7 +93,7 @@ T Must(absl::StatusOr status_or) { } template -void TestHasField(MemoryManager& memory_manager, ProtoStructType::FieldId id, +void TestHasField(MemoryManager& memory_manager, absl::string_view name, TestMessageMaker&& test_message_maker, bool found = true) { TypeFactory type_factory(memory_manager); ProtoTypeProvider type_provider; @@ -101,17 +101,17 @@ void TestHasField(MemoryManager& memory_manager, ProtoStructType::FieldId id, ValueFactory value_factory(type_manager); ASSERT_OK_AND_ASSIGN(auto value_without, ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT( - value_without->HasField(StructValue::HasFieldContext(type_manager), id), - IsOkAndHolds(Eq(false))); + EXPECT_THAT(value_without->HasFieldByName( + StructValue::HasFieldContext(type_manager), name), + IsOkAndHolds(Eq(false))); ASSERT_OK_AND_ASSIGN( auto value_with, ProtoValue::Create(value_factory, CreateTestMessage(std::forward( test_message_maker)))); - EXPECT_THAT( - value_with->HasField(StructValue::HasFieldContext(type_manager), id), - IsOkAndHolds(Eq(found))); + EXPECT_THAT(value_with->HasFieldByName( + StructValue::HasFieldContext(type_manager), name), + IsOkAndHolds(Eq(found))); } #define TEST_HAS_FIELD(...) ASSERT_NO_FATAL_FAILURE(TestHasField(__VA_ARGS__)) @@ -120,78 +120,78 @@ TEST_P(ProtoStructValueTest, NullValueHasField) { // In proto3, this can never be present as it will always be the default // value. We would need to add `optional` for it to work. TEST_HAS_FIELD( - memory_manager(), FieldId("null_value"), + memory_manager(), "null_value", [](TestAllTypes& message) { message.set_null_value(NULL_VALUE); }, false); } TEST_P(ProtoStructValueTest, OptionalNullValueHasField) { - TEST_HAS_FIELD(memory_manager(), FieldId("optional_null_value"), + TEST_HAS_FIELD(memory_manager(), "optional_null_value", [](TestAllTypes& message) { message.set_optional_null_value(NULL_VALUE); }); } TEST_P(ProtoStructValueTest, BoolHasField) { - TEST_HAS_FIELD(memory_manager(), FieldId("single_bool"), + TEST_HAS_FIELD(memory_manager(), "single_bool", [](TestAllTypes& message) { message.set_single_bool(true); }); } TEST_P(ProtoStructValueTest, Int32HasField) { - TEST_HAS_FIELD(memory_manager(), FieldId("single_int32"), + TEST_HAS_FIELD(memory_manager(), "single_int32", [](TestAllTypes& message) { message.set_single_int32(1); }); } TEST_P(ProtoStructValueTest, Int64HasField) { - TEST_HAS_FIELD(memory_manager(), FieldId("single_int64"), + TEST_HAS_FIELD(memory_manager(), "single_int64", [](TestAllTypes& message) { message.set_single_int64(1); }); } TEST_P(ProtoStructValueTest, Uint32HasField) { - TEST_HAS_FIELD(memory_manager(), FieldId("single_uint32"), + TEST_HAS_FIELD(memory_manager(), "single_uint32", [](TestAllTypes& message) { message.set_single_uint32(1); }); } TEST_P(ProtoStructValueTest, Uint64HasField) { - TEST_HAS_FIELD(memory_manager(), FieldId("single_uint64"), + TEST_HAS_FIELD(memory_manager(), "single_uint64", [](TestAllTypes& message) { message.set_single_uint64(1); }); } TEST_P(ProtoStructValueTest, FloatHasField) { - TEST_HAS_FIELD(memory_manager(), FieldId("single_float"), + TEST_HAS_FIELD(memory_manager(), "single_float", [](TestAllTypes& message) { message.set_single_float(1.0); }); } TEST_P(ProtoStructValueTest, DoubleHasField) { - TEST_HAS_FIELD(memory_manager(), FieldId("single_double"), + TEST_HAS_FIELD(memory_manager(), "single_double", [](TestAllTypes& message) { message.set_single_double(1.0); }); } TEST_P(ProtoStructValueTest, BytesHasField) { - TEST_HAS_FIELD( - memory_manager(), FieldId("single_bytes"), - [](TestAllTypes& message) { message.set_single_bytes("foo"); }); + TEST_HAS_FIELD(memory_manager(), "single_bytes", [](TestAllTypes& message) { + message.set_single_bytes("foo"); + }); } TEST_P(ProtoStructValueTest, StringHasField) { - TEST_HAS_FIELD( - memory_manager(), FieldId("single_string"), - [](TestAllTypes& message) { message.set_single_string("foo"); }); + TEST_HAS_FIELD(memory_manager(), "single_string", [](TestAllTypes& message) { + message.set_single_string("foo"); + }); } TEST_P(ProtoStructValueTest, DurationHasField) { TEST_HAS_FIELD( - memory_manager(), FieldId("single_duration"), + memory_manager(), "single_duration", [](TestAllTypes& message) { message.mutable_single_duration(); }); } TEST_P(ProtoStructValueTest, TimestampHasField) { TEST_HAS_FIELD( - memory_manager(), FieldId("single_timestamp"), + memory_manager(), "single_timestamp", [](TestAllTypes& message) { message.mutable_single_timestamp(); }); } TEST_P(ProtoStructValueTest, EnumHasField) { - TEST_HAS_FIELD(memory_manager(), FieldId("standalone_enum"), + TEST_HAS_FIELD(memory_manager(), "standalone_enum", [](TestAllTypes& message) { message.set_standalone_enum(TestAllTypes::BAR); }); @@ -199,154 +199,154 @@ TEST_P(ProtoStructValueTest, EnumHasField) { TEST_P(ProtoStructValueTest, MessageHasField) { TEST_HAS_FIELD( - memory_manager(), FieldId("standalone_message"), + memory_manager(), "standalone_message", [](TestAllTypes& message) { message.mutable_standalone_message(); }); } TEST_P(ProtoStructValueTest, BoolWrapperHasField) { TEST_HAS_FIELD( - memory_manager(), FieldId("single_bool_wrapper"), + memory_manager(), "single_bool_wrapper", [](TestAllTypes& message) { message.mutable_single_bool_wrapper(); }); } TEST_P(ProtoStructValueTest, Int32WrapperHasField) { TEST_HAS_FIELD( - memory_manager(), FieldId("single_int32_wrapper"), + memory_manager(), "single_int32_wrapper", [](TestAllTypes& message) { message.mutable_single_int32_wrapper(); }); } TEST_P(ProtoStructValueTest, Int64WrapperHasField) { TEST_HAS_FIELD( - memory_manager(), FieldId("single_int64_wrapper"), + memory_manager(), "single_int64_wrapper", [](TestAllTypes& message) { message.mutable_single_int64_wrapper(); }); } TEST_P(ProtoStructValueTest, UInt32WrapperHasField) { TEST_HAS_FIELD( - memory_manager(), FieldId("single_uint32_wrapper"), + memory_manager(), "single_uint32_wrapper", [](TestAllTypes& message) { message.mutable_single_uint32_wrapper(); }); } TEST_P(ProtoStructValueTest, UInt64WrapperHasField) { TEST_HAS_FIELD( - memory_manager(), FieldId("single_uint64_wrapper"), + memory_manager(), "single_uint64_wrapper", [](TestAllTypes& message) { message.mutable_single_uint64_wrapper(); }); } TEST_P(ProtoStructValueTest, FloatWrapperHasField) { TEST_HAS_FIELD( - memory_manager(), FieldId("single_float_wrapper"), + memory_manager(), "single_float_wrapper", [](TestAllTypes& message) { message.mutable_single_float_wrapper(); }); } TEST_P(ProtoStructValueTest, DoubleWrapperHasField) { TEST_HAS_FIELD( - memory_manager(), FieldId("single_double_wrapper"), + memory_manager(), "single_double_wrapper", [](TestAllTypes& message) { message.mutable_single_double_wrapper(); }); } TEST_P(ProtoStructValueTest, BytesWrapperHasField) { TEST_HAS_FIELD( - memory_manager(), FieldId("single_bytes_wrapper"), + memory_manager(), "single_bytes_wrapper", [](TestAllTypes& message) { message.mutable_single_bytes_wrapper(); }); } TEST_P(ProtoStructValueTest, StringWrapperHasField) { TEST_HAS_FIELD( - memory_manager(), FieldId("single_string_wrapper"), + memory_manager(), "single_string_wrapper", [](TestAllTypes& message) { message.mutable_single_string_wrapper(); }); } TEST_P(ProtoStructValueTest, ListValueHasField) { - TEST_HAS_FIELD(memory_manager(), FieldId("list_value"), + TEST_HAS_FIELD(memory_manager(), "list_value", [](TestAllTypes& message) { message.mutable_list_value(); }); } TEST_P(ProtoStructValueTest, StructHasField) { - TEST_HAS_FIELD( - memory_manager(), FieldId("single_struct"), - [](TestAllTypes& message) { message.mutable_single_struct(); }); + TEST_HAS_FIELD(memory_manager(), "single_struct", [](TestAllTypes& message) { + message.mutable_single_struct(); + }); } TEST_P(ProtoStructValueTest, ValueHasField) { - TEST_HAS_FIELD(memory_manager(), FieldId("single_value"), + TEST_HAS_FIELD(memory_manager(), "single_value", [](TestAllTypes& message) { message.mutable_single_value(); }); } TEST_P(ProtoStructValueTest, NullValueListHasField) { - TEST_HAS_FIELD(memory_manager(), FieldId("repeated_null_value"), + TEST_HAS_FIELD(memory_manager(), "repeated_null_value", [](TestAllTypes& message) { message.add_repeated_null_value(NULL_VALUE); }); } TEST_P(ProtoStructValueTest, BoolListHasField) { - TEST_HAS_FIELD( - memory_manager(), FieldId("repeated_bool"), - [](TestAllTypes& message) { message.add_repeated_bool(true); }); + TEST_HAS_FIELD(memory_manager(), "repeated_bool", [](TestAllTypes& message) { + message.add_repeated_bool(true); + }); } TEST_P(ProtoStructValueTest, Int32ListHasField) { - TEST_HAS_FIELD( - memory_manager(), FieldId("repeated_int32"), - [](TestAllTypes& message) { message.add_repeated_int32(true); }); + TEST_HAS_FIELD(memory_manager(), "repeated_int32", [](TestAllTypes& message) { + message.add_repeated_int32(true); + }); } TEST_P(ProtoStructValueTest, Int64ListHasField) { - TEST_HAS_FIELD(memory_manager(), FieldId("repeated_int64"), + TEST_HAS_FIELD(memory_manager(), "repeated_int64", [](TestAllTypes& message) { message.add_repeated_int64(1); }); } TEST_P(ProtoStructValueTest, Uint32ListHasField) { - TEST_HAS_FIELD(memory_manager(), FieldId("repeated_uint32"), + TEST_HAS_FIELD(memory_manager(), "repeated_uint32", [](TestAllTypes& message) { message.add_repeated_uint32(1); }); } TEST_P(ProtoStructValueTest, Uint64ListHasField) { - TEST_HAS_FIELD(memory_manager(), FieldId("repeated_uint64"), + TEST_HAS_FIELD(memory_manager(), "repeated_uint64", [](TestAllTypes& message) { message.add_repeated_uint64(1); }); } TEST_P(ProtoStructValueTest, FloatListHasField) { - TEST_HAS_FIELD( - memory_manager(), FieldId("repeated_float"), - [](TestAllTypes& message) { message.add_repeated_float(1.0); }); + TEST_HAS_FIELD(memory_manager(), "repeated_float", [](TestAllTypes& message) { + message.add_repeated_float(1.0); + }); } TEST_P(ProtoStructValueTest, DoubleListHasField) { TEST_HAS_FIELD( - memory_manager(), FieldId("repeated_double"), + memory_manager(), "repeated_double", [](TestAllTypes& message) { message.add_repeated_double(1.0); }); } TEST_P(ProtoStructValueTest, BytesListHasField) { - TEST_HAS_FIELD( - memory_manager(), FieldId("repeated_bytes"), - [](TestAllTypes& message) { message.add_repeated_bytes("foo"); }); + TEST_HAS_FIELD(memory_manager(), "repeated_bytes", [](TestAllTypes& message) { + message.add_repeated_bytes("foo"); + }); } TEST_P(ProtoStructValueTest, StringListHasField) { TEST_HAS_FIELD( - memory_manager(), FieldId("repeated_string"), + memory_manager(), "repeated_string", [](TestAllTypes& message) { message.add_repeated_string("foo"); }); } TEST_P(ProtoStructValueTest, DurationListHasField) { - TEST_HAS_FIELD(memory_manager(), FieldId("repeated_duration"), + TEST_HAS_FIELD(memory_manager(), "repeated_duration", [](TestAllTypes& message) { message.add_repeated_duration()->set_seconds(1); }); } TEST_P(ProtoStructValueTest, TimestampListHasField) { - TEST_HAS_FIELD(memory_manager(), FieldId("repeated_timestamp"), + TEST_HAS_FIELD(memory_manager(), "repeated_timestamp", [](TestAllTypes& message) { message.add_repeated_timestamp()->set_seconds(1); }); } TEST_P(ProtoStructValueTest, EnumListHasField) { - TEST_HAS_FIELD(memory_manager(), FieldId("repeated_nested_enum"), + TEST_HAS_FIELD(memory_manager(), "repeated_nested_enum", [](TestAllTypes& message) { message.add_repeated_nested_enum(TestAllTypes::BAR); }); @@ -354,82 +354,82 @@ TEST_P(ProtoStructValueTest, EnumListHasField) { TEST_P(ProtoStructValueTest, MessageListHasField) { TEST_HAS_FIELD( - memory_manager(), FieldId("repeated_nested_message"), + memory_manager(), "repeated_nested_message", [](TestAllTypes& message) { message.add_repeated_nested_message(); }); } TEST_P(ProtoStructValueTest, BoolWrapperListHasField) { TEST_HAS_FIELD( - memory_manager(), FieldId("repeated_bool_wrapper"), + memory_manager(), "repeated_bool_wrapper", [](TestAllTypes& message) { message.add_repeated_bool_wrapper(); }); } TEST_P(ProtoStructValueTest, Int32WrapperListHasField) { TEST_HAS_FIELD( - memory_manager(), FieldId("repeated_int32_wrapper"), + memory_manager(), "repeated_int32_wrapper", [](TestAllTypes& message) { message.add_repeated_int32_wrapper(); }); } TEST_P(ProtoStructValueTest, Int64WrapperListHasField) { TEST_HAS_FIELD( - memory_manager(), FieldId("repeated_int64_wrapper"), + memory_manager(), "repeated_int64_wrapper", [](TestAllTypes& message) { message.add_repeated_int64_wrapper(); }); } TEST_P(ProtoStructValueTest, Uint32WrapperListHasField) { TEST_HAS_FIELD( - memory_manager(), FieldId("repeated_uint32_wrapper"), + memory_manager(), "repeated_uint32_wrapper", [](TestAllTypes& message) { message.add_repeated_uint32_wrapper(); }); } TEST_P(ProtoStructValueTest, Uint64WrapperListHasField) { TEST_HAS_FIELD( - memory_manager(), FieldId("repeated_uint64_wrapper"), + memory_manager(), "repeated_uint64_wrapper", [](TestAllTypes& message) { message.add_repeated_uint64_wrapper(); }); } TEST_P(ProtoStructValueTest, FloatWrapperListHasField) { TEST_HAS_FIELD( - memory_manager(), FieldId("repeated_float_wrapper"), + memory_manager(), "repeated_float_wrapper", [](TestAllTypes& message) { message.add_repeated_float_wrapper(); }); } TEST_P(ProtoStructValueTest, DoubleWrapperListHasField) { TEST_HAS_FIELD( - memory_manager(), FieldId("repeated_double_wrapper"), + memory_manager(), "repeated_double_wrapper", [](TestAllTypes& message) { message.add_repeated_double_wrapper(); }); } TEST_P(ProtoStructValueTest, BytesWrapperListHasField) { TEST_HAS_FIELD( - memory_manager(), FieldId("repeated_bytes_wrapper"), + memory_manager(), "repeated_bytes_wrapper", [](TestAllTypes& message) { message.add_repeated_bytes_wrapper(); }); } TEST_P(ProtoStructValueTest, StringWrapperListHasField) { TEST_HAS_FIELD( - memory_manager(), FieldId("repeated_string_wrapper"), + memory_manager(), "repeated_string_wrapper", [](TestAllTypes& message) { message.add_repeated_string_wrapper(); }); } TEST_P(ProtoStructValueTest, ListValueListHasField) { TEST_HAS_FIELD( - memory_manager(), FieldId("repeated_list_value"), + memory_manager(), "repeated_list_value", [](TestAllTypes& message) { message.add_repeated_list_value(); }); } TEST_P(ProtoStructValueTest, StructListHasField) { - TEST_HAS_FIELD(memory_manager(), FieldId("repeated_struct"), + TEST_HAS_FIELD(memory_manager(), "repeated_struct", [](TestAllTypes& message) { message.add_repeated_struct(); }); } TEST_P(ProtoStructValueTest, ValueListHasField) { - TEST_HAS_FIELD(memory_manager(), FieldId("repeated_value"), + TEST_HAS_FIELD(memory_manager(), "repeated_value", [](TestAllTypes& message) { message.add_repeated_value(); }); } void TestGetField( - MemoryManager& memory_manager, FieldId id, + MemoryManager& memory_manager, absl::string_view name, absl::FunctionRef&)> unset_field_tester, absl::FunctionRef test_message_maker, absl::FunctionRef&)> @@ -440,25 +440,25 @@ void TestGetField( ValueFactory value_factory(type_manager); ASSERT_OK_AND_ASSIGN(auto value_without, ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN( - auto field, - value_without->GetField(StructValue::GetFieldContext(value_factory), id)); + ASSERT_OK_AND_ASSIGN(auto field, + value_without->GetFieldByName( + StructValue::GetFieldContext(value_factory), name)); ASSERT_NO_FATAL_FAILURE(unset_field_tester(field)); ASSERT_OK_AND_ASSIGN( auto value_with, ProtoValue::Create(value_factory, CreateTestMessage(test_message_maker))); - ASSERT_OK_AND_ASSIGN( - field, - value_with->GetField(StructValue::GetFieldContext(value_factory), id)); + ASSERT_OK_AND_ASSIGN(field, + value_with->GetFieldByName( + StructValue::GetFieldContext(value_factory), name)); ASSERT_NO_FATAL_FAILURE(set_field_tester(value_factory, field)); } void TestGetField( - MemoryManager& memory_manager, FieldId id, + MemoryManager& memory_manager, absl::string_view name, absl::FunctionRef&)> unset_field_tester, absl::FunctionRef test_message_maker, absl::FunctionRef&)> set_field_tester) { - TestGetField(memory_manager, id, unset_field_tester, test_message_maker, + TestGetField(memory_manager, name, unset_field_tester, test_message_maker, [&](ValueFactory& value_factory, const Handle& field) { set_field_tester(field); }); @@ -468,7 +468,7 @@ void TestGetField( TEST_P(ProtoStructValueTest, NullValueGetField) { TEST_GET_FIELD( - memory_manager(), FieldId("null_value"), + memory_manager(), "null_value", [](const Handle& field) { EXPECT_TRUE(field->Is()); }, [](TestAllTypes& message) { message.set_null_value(NULL_VALUE); }, [](const Handle& field) { EXPECT_TRUE(field->Is()); }); @@ -476,7 +476,7 @@ TEST_P(ProtoStructValueTest, NullValueGetField) { TEST_P(ProtoStructValueTest, OptionalNullValueGetField) { TEST_GET_FIELD( - memory_manager(), FieldId("optional_null_value"), + memory_manager(), "optional_null_value", [](const Handle& field) { EXPECT_TRUE(field->Is()); }, [](TestAllTypes& message) { message.set_optional_null_value(NULL_VALUE); @@ -486,7 +486,7 @@ TEST_P(ProtoStructValueTest, OptionalNullValueGetField) { TEST_P(ProtoStructValueTest, BoolGetField) { TEST_GET_FIELD( - memory_manager(), FieldId("single_bool"), + memory_manager(), "single_bool", [](const Handle& field) { EXPECT_FALSE(field.As()->value()); }, @@ -498,7 +498,7 @@ TEST_P(ProtoStructValueTest, BoolGetField) { TEST_P(ProtoStructValueTest, Int32GetField) { TEST_GET_FIELD( - memory_manager(), FieldId("single_int32"), + memory_manager(), "single_int32", [](const Handle& field) { EXPECT_EQ(field.As()->value(), 0); }, @@ -510,7 +510,7 @@ TEST_P(ProtoStructValueTest, Int32GetField) { TEST_P(ProtoStructValueTest, Int64GetField) { TEST_GET_FIELD( - memory_manager(), FieldId("single_int64"), + memory_manager(), "single_int64", [](const Handle& field) { EXPECT_EQ(field.As()->value(), 0); }, @@ -522,7 +522,7 @@ TEST_P(ProtoStructValueTest, Int64GetField) { TEST_P(ProtoStructValueTest, Uint32GetField) { TEST_GET_FIELD( - memory_manager(), FieldId("single_uint32"), + memory_manager(), "single_uint32", [](const Handle& field) { EXPECT_EQ(field.As()->value(), 0); }, @@ -534,7 +534,7 @@ TEST_P(ProtoStructValueTest, Uint32GetField) { TEST_P(ProtoStructValueTest, Uint64GetField) { TEST_GET_FIELD( - memory_manager(), FieldId("single_uint64"), + memory_manager(), "single_uint64", [](const Handle& field) { EXPECT_EQ(field.As()->value(), 0); }, @@ -546,7 +546,7 @@ TEST_P(ProtoStructValueTest, Uint64GetField) { TEST_P(ProtoStructValueTest, FloatGetField) { TEST_GET_FIELD( - memory_manager(), FieldId("single_float"), + memory_manager(), "single_float", [](const Handle& field) { EXPECT_EQ(field.As()->value(), 0); }, @@ -558,7 +558,7 @@ TEST_P(ProtoStructValueTest, FloatGetField) { TEST_P(ProtoStructValueTest, DoubleGetField) { TEST_GET_FIELD( - memory_manager(), FieldId("single_double"), + memory_manager(), "single_double", [](const Handle& field) { EXPECT_EQ(field.As()->value(), 0); }, @@ -570,7 +570,7 @@ TEST_P(ProtoStructValueTest, DoubleGetField) { TEST_P(ProtoStructValueTest, BytesGetField) { TEST_GET_FIELD( - memory_manager(), FieldId("single_bytes"), + memory_manager(), "single_bytes", [](const Handle& field) { EXPECT_EQ(field.As()->ToString(), ""); }, @@ -582,7 +582,7 @@ TEST_P(ProtoStructValueTest, BytesGetField) { TEST_P(ProtoStructValueTest, StringGetField) { TEST_GET_FIELD( - memory_manager(), FieldId("single_string"), + memory_manager(), "single_string", [](const Handle& field) { EXPECT_EQ(field.As()->ToString(), ""); }, @@ -594,7 +594,7 @@ TEST_P(ProtoStructValueTest, StringGetField) { TEST_P(ProtoStructValueTest, DurationGetField) { TEST_GET_FIELD( - memory_manager(), FieldId("single_duration"), + memory_manager(), "single_duration", [](const Handle& field) { EXPECT_EQ(field.As()->value(), absl::ZeroDuration()); }, @@ -608,7 +608,7 @@ TEST_P(ProtoStructValueTest, DurationGetField) { TEST_P(ProtoStructValueTest, TimestampGetField) { TEST_GET_FIELD( - memory_manager(), FieldId("single_timestamp"), + memory_manager(), "single_timestamp", [](const Handle& field) { EXPECT_EQ(field.As()->value(), absl::UnixEpoch()); }, @@ -623,7 +623,7 @@ TEST_P(ProtoStructValueTest, TimestampGetField) { TEST_P(ProtoStructValueTest, EnumGetField) { TEST_GET_FIELD( - memory_manager(), FieldId("standalone_enum"), + memory_manager(), "standalone_enum", [](const Handle& field) { EXPECT_EQ(field.As()->number(), 0); }, @@ -637,7 +637,7 @@ TEST_P(ProtoStructValueTest, EnumGetField) { TEST_P(ProtoStructValueTest, MessageGetField) { TEST_GET_FIELD( - memory_manager(), FieldId("standalone_message"), + memory_manager(), "standalone_message", [](const Handle& field) { EXPECT_THAT(*field.As()->value(), EqualsProto(CreateTestMessage().standalone_message())); @@ -663,7 +663,7 @@ TEST_P(ProtoStructValueTest, MessageGetField) { template -void TestGetWrapperField(MemoryManager& memory_manager, FieldId id, +void TestGetWrapperField(MemoryManager& memory_manager, absl::string_view name, UnsetFieldTester&& unset_field_tester, TestMessageMaker&& test_message_maker, SetFieldTester&& set_field_tester) { @@ -675,14 +675,14 @@ void TestGetWrapperField(MemoryManager& memory_manager, FieldId id, ProtoValue::Create(value_factory, CreateTestMessage())); ASSERT_OK_AND_ASSIGN( auto field, - value_without->GetField(StructValue::GetFieldContext(value_factory) - .set_unbox_null_wrapper_types(true), - id)); + value_without->GetFieldByName(StructValue::GetFieldContext(value_factory) + .set_unbox_null_wrapper_types(true), + name)); EXPECT_TRUE(field->Is()); - ASSERT_OK_AND_ASSIGN( - field, value_without->GetField(StructValue::GetFieldContext(value_factory) - .set_unbox_null_wrapper_types(false), - id)); + ASSERT_OK_AND_ASSIGN(field, value_without->GetFieldByName( + StructValue::GetFieldContext(value_factory) + .set_unbox_null_wrapper_types(false), + name)); ASSERT_NO_FATAL_FAILURE( (std::forward(unset_field_tester)(field))); ASSERT_OK_AND_ASSIGN( @@ -690,9 +690,9 @@ void TestGetWrapperField(MemoryManager& memory_manager, FieldId id, ProtoValue::Create(value_factory, CreateTestMessage(std::forward( test_message_maker)))); - ASSERT_OK_AND_ASSIGN( - field, - value_with->GetField(StructValue::GetFieldContext(value_factory), id)); + ASSERT_OK_AND_ASSIGN(field, + value_with->GetFieldByName( + StructValue::GetFieldContext(value_factory), name)); ASSERT_NO_FATAL_FAILURE( (std::forward(set_field_tester)(field))); } @@ -702,7 +702,7 @@ void TestGetWrapperField(MemoryManager& memory_manager, FieldId id, TEST_P(ProtoStructValueTest, BoolWrapperGetField) { TEST_GET_WRAPPER_FIELD( - memory_manager(), FieldId("single_bool_wrapper"), + memory_manager(), "single_bool_wrapper", [](const Handle& field) { EXPECT_FALSE(field.As()->value()); }, @@ -716,7 +716,7 @@ TEST_P(ProtoStructValueTest, BoolWrapperGetField) { TEST_P(ProtoStructValueTest, Int32WrapperGetField) { TEST_GET_WRAPPER_FIELD( - memory_manager(), FieldId("single_int32_wrapper"), + memory_manager(), "single_int32_wrapper", [](const Handle& field) { EXPECT_EQ(field.As()->value(), 0); }, @@ -730,7 +730,7 @@ TEST_P(ProtoStructValueTest, Int32WrapperGetField) { TEST_P(ProtoStructValueTest, Int64WrapperGetField) { TEST_GET_WRAPPER_FIELD( - memory_manager(), FieldId("single_int64_wrapper"), + memory_manager(), "single_int64_wrapper", [](const Handle& field) { EXPECT_EQ(field.As()->value(), 0); }, @@ -744,7 +744,7 @@ TEST_P(ProtoStructValueTest, Int64WrapperGetField) { TEST_P(ProtoStructValueTest, Uint32WrapperGetField) { TEST_GET_WRAPPER_FIELD( - memory_manager(), FieldId("single_uint32_wrapper"), + memory_manager(), "single_uint32_wrapper", [](const Handle& field) { EXPECT_EQ(field.As()->value(), 0); }, @@ -758,7 +758,7 @@ TEST_P(ProtoStructValueTest, Uint32WrapperGetField) { TEST_P(ProtoStructValueTest, Uint64WrapperGetField) { TEST_GET_WRAPPER_FIELD( - memory_manager(), FieldId("single_uint64_wrapper"), + memory_manager(), "single_uint64_wrapper", [](const Handle& field) { EXPECT_EQ(field.As()->value(), 0); }, @@ -772,7 +772,7 @@ TEST_P(ProtoStructValueTest, Uint64WrapperGetField) { TEST_P(ProtoStructValueTest, FloatWrapperGetField) { TEST_GET_WRAPPER_FIELD( - memory_manager(), FieldId("single_float_wrapper"), + memory_manager(), "single_float_wrapper", [](const Handle& field) { EXPECT_EQ(field.As()->value(), 0); }, @@ -786,7 +786,7 @@ TEST_P(ProtoStructValueTest, FloatWrapperGetField) { TEST_P(ProtoStructValueTest, DoubleWrapperGetField) { TEST_GET_WRAPPER_FIELD( - memory_manager(), FieldId("single_double_wrapper"), + memory_manager(), "single_double_wrapper", [](const Handle& field) { EXPECT_EQ(field.As()->value(), 0); }, @@ -800,7 +800,7 @@ TEST_P(ProtoStructValueTest, DoubleWrapperGetField) { TEST_P(ProtoStructValueTest, BytesWrapperGetField) { TEST_GET_WRAPPER_FIELD( - memory_manager(), FieldId("single_bytes_wrapper"), + memory_manager(), "single_bytes_wrapper", [](const Handle& field) { EXPECT_EQ(field.As()->ToString(), ""); }, @@ -814,7 +814,7 @@ TEST_P(ProtoStructValueTest, BytesWrapperGetField) { TEST_P(ProtoStructValueTest, StringWrapperGetField) { TEST_GET_WRAPPER_FIELD( - memory_manager(), FieldId("single_string_wrapper"), + memory_manager(), "single_string_wrapper", [](const Handle& field) { EXPECT_EQ(field.As()->ToString(), ""); }, @@ -828,7 +828,7 @@ TEST_P(ProtoStructValueTest, StringWrapperGetField) { TEST_P(ProtoStructValueTest, StructGetField) { TEST_GET_FIELD( - memory_manager(), FieldId("single_struct"), + memory_manager(), "single_struct", [](const Handle& field) { ASSERT_TRUE(field->Is()); EXPECT_TRUE(field->As().empty()); @@ -851,7 +851,7 @@ TEST_P(ProtoStructValueTest, StructGetField) { TEST_P(ProtoStructValueTest, ListValueGetField) { TEST_GET_FIELD( - memory_manager(), FieldId("list_value"), + memory_manager(), "list_value", [](const Handle& field) { ASSERT_TRUE(field->Is()); EXPECT_TRUE(field->As().empty()); @@ -870,7 +870,7 @@ TEST_P(ProtoStructValueTest, ListValueGetField) { TEST_P(ProtoStructValueTest, ValueGetField) { TEST_GET_FIELD( - memory_manager(), FieldId("single_value"), + memory_manager(), "single_value", [](const Handle& field) { EXPECT_TRUE(field->Is()); }, [](TestAllTypes& message) { message.mutable_single_value()->set_bool_value(true); @@ -889,8 +889,8 @@ TEST_P(ProtoStructValueTest, NullValueListGetField) { ProtoValue::Create(value_factory, CreateTestMessage())); ASSERT_OK_AND_ASSIGN( auto field, - value_without->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("repeated_null_value"))); + value_without->GetFieldByName(StructValue::GetFieldContext(value_factory), + "repeated_null_value")); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 0); EXPECT_TRUE(field.As()->empty()); @@ -902,10 +902,9 @@ TEST_P(ProtoStructValueTest, NullValueListGetField) { message.add_repeated_null_value(NULL_VALUE); message.add_repeated_null_value(NULL_VALUE); }))); - ASSERT_OK_AND_ASSIGN( - field, - value_with->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("repeated_null_value"))); + ASSERT_OK_AND_ASSIGN(field, value_with->GetFieldByName( + StructValue::GetFieldContext(value_factory), + "repeated_null_value")); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 2); EXPECT_FALSE(field.As()->empty()); @@ -929,8 +928,8 @@ TEST_P(ProtoStructValueTest, BoolListGetField) { ProtoValue::Create(value_factory, CreateTestMessage())); ASSERT_OK_AND_ASSIGN( auto field, - value_without->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("repeated_bool"))); + value_without->GetFieldByName(StructValue::GetFieldContext(value_factory), + "repeated_bool")); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 0); EXPECT_TRUE(field.As()->empty()); @@ -943,8 +942,8 @@ TEST_P(ProtoStructValueTest, BoolListGetField) { message.add_repeated_bool(false); }))); ASSERT_OK_AND_ASSIGN( - field, value_with->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("repeated_bool"))); + field, value_with->GetFieldByName( + StructValue::GetFieldContext(value_factory), "repeated_bool")); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 2); EXPECT_FALSE(field.As()->empty()); @@ -968,8 +967,8 @@ TEST_P(ProtoStructValueTest, Int32ListGetField) { ProtoValue::Create(value_factory, CreateTestMessage())); ASSERT_OK_AND_ASSIGN( auto field, - value_without->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("repeated_int32"))); + value_without->GetFieldByName(StructValue::GetFieldContext(value_factory), + "repeated_int32")); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 0); EXPECT_TRUE(field.As()->empty()); @@ -981,9 +980,9 @@ TEST_P(ProtoStructValueTest, Int32ListGetField) { message.add_repeated_int32(1); message.add_repeated_int32(0); }))); - ASSERT_OK_AND_ASSIGN( - field, value_with->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("repeated_int32"))); + ASSERT_OK_AND_ASSIGN(field, value_with->GetFieldByName( + StructValue::GetFieldContext(value_factory), + "repeated_int32")); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 2); EXPECT_FALSE(field.As()->empty()); @@ -1007,8 +1006,8 @@ TEST_P(ProtoStructValueTest, Int64ListGetField) { ProtoValue::Create(value_factory, CreateTestMessage())); ASSERT_OK_AND_ASSIGN( auto field, - value_without->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("repeated_int64"))); + value_without->GetFieldByName(StructValue::GetFieldContext(value_factory), + "repeated_int64")); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 0); EXPECT_TRUE(field.As()->empty()); @@ -1020,9 +1019,9 @@ TEST_P(ProtoStructValueTest, Int64ListGetField) { message.add_repeated_int64(1); message.add_repeated_int64(0); }))); - ASSERT_OK_AND_ASSIGN( - field, value_with->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("repeated_int64"))); + ASSERT_OK_AND_ASSIGN(field, value_with->GetFieldByName( + StructValue::GetFieldContext(value_factory), + "repeated_int64")); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 2); EXPECT_FALSE(field.As()->empty()); @@ -1046,8 +1045,8 @@ TEST_P(ProtoStructValueTest, Uint32ListGetField) { ProtoValue::Create(value_factory, CreateTestMessage())); ASSERT_OK_AND_ASSIGN( auto field, - value_without->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("repeated_uint32"))); + value_without->GetFieldByName(StructValue::GetFieldContext(value_factory), + "repeated_uint32")); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 0); EXPECT_TRUE(field.As()->empty()); @@ -1059,9 +1058,9 @@ TEST_P(ProtoStructValueTest, Uint32ListGetField) { message.add_repeated_uint32(1); message.add_repeated_uint32(0); }))); - ASSERT_OK_AND_ASSIGN( - field, value_with->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("repeated_uint32"))); + ASSERT_OK_AND_ASSIGN(field, value_with->GetFieldByName( + StructValue::GetFieldContext(value_factory), + "repeated_uint32")); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 2); EXPECT_FALSE(field.As()->empty()); @@ -1085,8 +1084,8 @@ TEST_P(ProtoStructValueTest, Uint64ListGetField) { ProtoValue::Create(value_factory, CreateTestMessage())); ASSERT_OK_AND_ASSIGN( auto field, - value_without->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("repeated_uint64"))); + value_without->GetFieldByName(StructValue::GetFieldContext(value_factory), + "repeated_uint64")); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 0); EXPECT_TRUE(field.As()->empty()); @@ -1098,9 +1097,9 @@ TEST_P(ProtoStructValueTest, Uint64ListGetField) { message.add_repeated_uint64(1); message.add_repeated_uint64(0); }))); - ASSERT_OK_AND_ASSIGN( - field, value_with->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("repeated_uint64"))); + ASSERT_OK_AND_ASSIGN(field, value_with->GetFieldByName( + StructValue::GetFieldContext(value_factory), + "repeated_uint64")); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 2); EXPECT_FALSE(field.As()->empty()); @@ -1124,8 +1123,8 @@ TEST_P(ProtoStructValueTest, FloatListGetField) { ProtoValue::Create(value_factory, CreateTestMessage())); ASSERT_OK_AND_ASSIGN( auto field, - value_without->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("repeated_float"))); + value_without->GetFieldByName(StructValue::GetFieldContext(value_factory), + "repeated_float")); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 0); EXPECT_TRUE(field.As()->empty()); @@ -1137,9 +1136,9 @@ TEST_P(ProtoStructValueTest, FloatListGetField) { message.add_repeated_float(1.0); message.add_repeated_float(0.0); }))); - ASSERT_OK_AND_ASSIGN( - field, value_with->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("repeated_float"))); + ASSERT_OK_AND_ASSIGN(field, value_with->GetFieldByName( + StructValue::GetFieldContext(value_factory), + "repeated_float")); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 2); EXPECT_FALSE(field.As()->empty()); @@ -1163,8 +1162,8 @@ TEST_P(ProtoStructValueTest, DoubleListGetField) { ProtoValue::Create(value_factory, CreateTestMessage())); ASSERT_OK_AND_ASSIGN( auto field, - value_without->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("repeated_double"))); + value_without->GetFieldByName(StructValue::GetFieldContext(value_factory), + "repeated_double")); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 0); EXPECT_TRUE(field.As()->empty()); @@ -1176,9 +1175,9 @@ TEST_P(ProtoStructValueTest, DoubleListGetField) { message.add_repeated_double(1.0); message.add_repeated_double(0.0); }))); - ASSERT_OK_AND_ASSIGN( - field, value_with->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("repeated_double"))); + ASSERT_OK_AND_ASSIGN(field, value_with->GetFieldByName( + StructValue::GetFieldContext(value_factory), + "repeated_double")); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 2); EXPECT_FALSE(field.As()->empty()); @@ -1202,8 +1201,8 @@ TEST_P(ProtoStructValueTest, BytesListGetField) { ProtoValue::Create(value_factory, CreateTestMessage())); ASSERT_OK_AND_ASSIGN( auto field, - value_without->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("repeated_bytes"))); + value_without->GetFieldByName(StructValue::GetFieldContext(value_factory), + "repeated_bytes")); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 0); EXPECT_TRUE(field.As()->empty()); @@ -1215,9 +1214,9 @@ TEST_P(ProtoStructValueTest, BytesListGetField) { message.add_repeated_bytes("foo"); message.add_repeated_bytes("bar"); }))); - ASSERT_OK_AND_ASSIGN( - field, value_with->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("repeated_bytes"))); + ASSERT_OK_AND_ASSIGN(field, value_with->GetFieldByName( + StructValue::GetFieldContext(value_factory), + "repeated_bytes")); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 2); EXPECT_FALSE(field.As()->empty()); @@ -1241,8 +1240,8 @@ TEST_P(ProtoStructValueTest, StringListGetField) { ProtoValue::Create(value_factory, CreateTestMessage())); ASSERT_OK_AND_ASSIGN( auto field, - value_without->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("repeated_string"))); + value_without->GetFieldByName(StructValue::GetFieldContext(value_factory), + "repeated_string")); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 0); EXPECT_TRUE(field.As()->empty()); @@ -1254,9 +1253,9 @@ TEST_P(ProtoStructValueTest, StringListGetField) { message.add_repeated_string("foo"); message.add_repeated_string("bar"); }))); - ASSERT_OK_AND_ASSIGN( - field, value_with->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("repeated_string"))); + ASSERT_OK_AND_ASSIGN(field, value_with->GetFieldByName( + StructValue::GetFieldContext(value_factory), + "repeated_string")); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 2); EXPECT_FALSE(field.As()->empty()); @@ -1280,8 +1279,8 @@ TEST_P(ProtoStructValueTest, DurationListGetField) { ProtoValue::Create(value_factory, CreateTestMessage())); ASSERT_OK_AND_ASSIGN( auto field, - value_without->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("repeated_duration"))); + value_without->GetFieldByName(StructValue::GetFieldContext(value_factory), + "repeated_duration")); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 0); EXPECT_TRUE(field.As()->empty()); @@ -1293,10 +1292,9 @@ TEST_P(ProtoStructValueTest, DurationListGetField) { message.add_repeated_duration()->set_seconds(1); message.add_repeated_duration()->set_seconds(2); }))); - ASSERT_OK_AND_ASSIGN( - field, - value_with->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("repeated_duration"))); + ASSERT_OK_AND_ASSIGN(field, value_with->GetFieldByName( + StructValue::GetFieldContext(value_factory), + "repeated_duration")); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 2); EXPECT_FALSE(field.As()->empty()); @@ -1320,8 +1318,8 @@ TEST_P(ProtoStructValueTest, TimestampListGetField) { ProtoValue::Create(value_factory, CreateTestMessage())); ASSERT_OK_AND_ASSIGN( auto field, - value_without->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("repeated_timestamp"))); + value_without->GetFieldByName(StructValue::GetFieldContext(value_factory), + "repeated_timestamp")); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 0); EXPECT_TRUE(field.As()->empty()); @@ -1333,10 +1331,9 @@ TEST_P(ProtoStructValueTest, TimestampListGetField) { message.add_repeated_timestamp()->set_seconds(1); message.add_repeated_timestamp()->set_seconds(2); }))); - ASSERT_OK_AND_ASSIGN( - field, - value_with->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("repeated_timestamp"))); + ASSERT_OK_AND_ASSIGN(field, value_with->GetFieldByName( + StructValue::GetFieldContext(value_factory), + "repeated_timestamp")); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 2); EXPECT_FALSE(field.As()->empty()); @@ -1361,10 +1358,10 @@ TEST_P(ProtoStructValueTest, EnumListGetField) { ValueFactory value_factory(type_manager); ASSERT_OK_AND_ASSIGN(auto value_without, ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN(auto field, - value_without->GetField( - StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("repeated_nested_enum"))); + ASSERT_OK_AND_ASSIGN( + auto field, + value_without->GetFieldByName(StructValue::GetFieldContext(value_factory), + "repeated_nested_enum")); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 0); EXPECT_TRUE(field.As()->empty()); @@ -1376,10 +1373,9 @@ TEST_P(ProtoStructValueTest, EnumListGetField) { message.add_repeated_nested_enum(TestAllTypes::FOO); message.add_repeated_nested_enum(TestAllTypes::BAR); }))); - ASSERT_OK_AND_ASSIGN( - field, - value_with->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("repeated_nested_enum"))); + ASSERT_OK_AND_ASSIGN(field, value_with->GetFieldByName( + StructValue::GetFieldContext(value_factory), + "repeated_nested_enum")); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 2); EXPECT_FALSE(field.As()->empty()); @@ -1404,9 +1400,9 @@ TEST_P(ProtoStructValueTest, StructListGetField) { ASSERT_OK_AND_ASSIGN(auto value_without, ProtoValue::Create(value_factory, CreateTestMessage())); ASSERT_OK_AND_ASSIGN( - auto field, value_without->GetField( - StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("repeated_nested_message"))); + auto field, + value_without->GetFieldByName(StructValue::GetFieldContext(value_factory), + "repeated_nested_message")); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 0); EXPECT_TRUE(field.As()->empty()); @@ -1418,10 +1414,9 @@ TEST_P(ProtoStructValueTest, StructListGetField) { message.add_repeated_nested_message()->set_bb(1); message.add_repeated_nested_message()->set_bb(2); }))); - ASSERT_OK_AND_ASSIGN( - field, value_with->GetField( - StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("repeated_nested_message"))); + ASSERT_OK_AND_ASSIGN(field, value_with->GetFieldByName( + StructValue::GetFieldContext(value_factory), + "repeated_nested_message")); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 2); EXPECT_FALSE(field.As()->empty()); @@ -1451,10 +1446,10 @@ TEST_P(ProtoStructValueTest, BoolWrapperListGetField) { ValueFactory value_factory(type_manager); ASSERT_OK_AND_ASSIGN(auto value_without, ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN(auto field, - value_without->GetField( - StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("repeated_bool_wrapper"))); + ASSERT_OK_AND_ASSIGN( + auto field, + value_without->GetFieldByName(StructValue::GetFieldContext(value_factory), + "repeated_bool_wrapper")); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 0); EXPECT_TRUE(field.As()->empty()); @@ -1466,10 +1461,9 @@ TEST_P(ProtoStructValueTest, BoolWrapperListGetField) { message.add_repeated_bool_wrapper()->set_value(true); message.add_repeated_bool_wrapper()->set_value(false); }))); - ASSERT_OK_AND_ASSIGN( - field, - value_with->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("repeated_bool_wrapper"))); + ASSERT_OK_AND_ASSIGN(field, value_with->GetFieldByName( + StructValue::GetFieldContext(value_factory), + "repeated_bool_wrapper")); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 2); EXPECT_FALSE(field.As()->empty()); @@ -1491,10 +1485,10 @@ TEST_P(ProtoStructValueTest, Int32WrapperListGetField) { ValueFactory value_factory(type_manager); ASSERT_OK_AND_ASSIGN(auto value_without, ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN(auto field, - value_without->GetField( - StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("repeated_int32_wrapper"))); + ASSERT_OK_AND_ASSIGN( + auto field, + value_without->GetFieldByName(StructValue::GetFieldContext(value_factory), + "repeated_int32_wrapper")); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 0); EXPECT_TRUE(field.As()->empty()); @@ -1506,10 +1500,9 @@ TEST_P(ProtoStructValueTest, Int32WrapperListGetField) { message.add_repeated_int32_wrapper()->set_value(1); message.add_repeated_int32_wrapper()->set_value(0); }))); - ASSERT_OK_AND_ASSIGN( - field, - value_with->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("repeated_int32_wrapper"))); + ASSERT_OK_AND_ASSIGN(field, value_with->GetFieldByName( + StructValue::GetFieldContext(value_factory), + "repeated_int32_wrapper")); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 2); EXPECT_FALSE(field.As()->empty()); @@ -1531,10 +1524,10 @@ TEST_P(ProtoStructValueTest, Int64WrapperListGetField) { ValueFactory value_factory(type_manager); ASSERT_OK_AND_ASSIGN(auto value_without, ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN(auto field, - value_without->GetField( - StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("repeated_int64_wrapper"))); + ASSERT_OK_AND_ASSIGN( + auto field, + value_without->GetFieldByName(StructValue::GetFieldContext(value_factory), + "repeated_int64_wrapper")); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 0); EXPECT_TRUE(field.As()->empty()); @@ -1546,10 +1539,9 @@ TEST_P(ProtoStructValueTest, Int64WrapperListGetField) { message.add_repeated_int64_wrapper()->set_value(1); message.add_repeated_int64_wrapper()->set_value(0); }))); - ASSERT_OK_AND_ASSIGN( - field, - value_with->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("repeated_int64_wrapper"))); + ASSERT_OK_AND_ASSIGN(field, value_with->GetFieldByName( + StructValue::GetFieldContext(value_factory), + "repeated_int64_wrapper")); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 2); EXPECT_FALSE(field.As()->empty()); @@ -1572,9 +1564,9 @@ TEST_P(ProtoStructValueTest, Uint32WrapperListGetField) { ASSERT_OK_AND_ASSIGN(auto value_without, ProtoValue::Create(value_factory, CreateTestMessage())); ASSERT_OK_AND_ASSIGN( - auto field, value_without->GetField( - StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("repeated_uint32_wrapper"))); + auto field, + value_without->GetFieldByName(StructValue::GetFieldContext(value_factory), + "repeated_uint32_wrapper")); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 0); EXPECT_TRUE(field.As()->empty()); @@ -1586,10 +1578,9 @@ TEST_P(ProtoStructValueTest, Uint32WrapperListGetField) { message.add_repeated_uint32_wrapper()->set_value(1); message.add_repeated_uint32_wrapper()->set_value(0); }))); - ASSERT_OK_AND_ASSIGN( - field, value_with->GetField( - StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("repeated_uint32_wrapper"))); + ASSERT_OK_AND_ASSIGN(field, value_with->GetFieldByName( + StructValue::GetFieldContext(value_factory), + "repeated_uint32_wrapper")); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 2); EXPECT_FALSE(field.As()->empty()); @@ -1612,9 +1603,9 @@ TEST_P(ProtoStructValueTest, Uint64WrapperListGetField) { ASSERT_OK_AND_ASSIGN(auto value_without, ProtoValue::Create(value_factory, CreateTestMessage())); ASSERT_OK_AND_ASSIGN( - auto field, value_without->GetField( - StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("repeated_uint64_wrapper"))); + auto field, + value_without->GetFieldByName(StructValue::GetFieldContext(value_factory), + "repeated_uint64_wrapper")); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 0); EXPECT_TRUE(field.As()->empty()); @@ -1626,10 +1617,9 @@ TEST_P(ProtoStructValueTest, Uint64WrapperListGetField) { message.add_repeated_uint64_wrapper()->set_value(1); message.add_repeated_uint64_wrapper()->set_value(0); }))); - ASSERT_OK_AND_ASSIGN( - field, value_with->GetField( - StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("repeated_uint64_wrapper"))); + ASSERT_OK_AND_ASSIGN(field, value_with->GetFieldByName( + StructValue::GetFieldContext(value_factory), + "repeated_uint64_wrapper")); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 2); EXPECT_FALSE(field.As()->empty()); @@ -1651,10 +1641,10 @@ TEST_P(ProtoStructValueTest, FloatWrapperListGetField) { ValueFactory value_factory(type_manager); ASSERT_OK_AND_ASSIGN(auto value_without, ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN(auto field, - value_without->GetField( - StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("repeated_float_wrapper"))); + ASSERT_OK_AND_ASSIGN( + auto field, + value_without->GetFieldByName(StructValue::GetFieldContext(value_factory), + "repeated_float_wrapper")); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 0); EXPECT_TRUE(field.As()->empty()); @@ -1666,10 +1656,9 @@ TEST_P(ProtoStructValueTest, FloatWrapperListGetField) { message.add_repeated_float_wrapper()->set_value(1.0); message.add_repeated_float_wrapper()->set_value(0.0); }))); - ASSERT_OK_AND_ASSIGN( - field, - value_with->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("repeated_float_wrapper"))); + ASSERT_OK_AND_ASSIGN(field, value_with->GetFieldByName( + StructValue::GetFieldContext(value_factory), + "repeated_float_wrapper")); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 2); EXPECT_FALSE(field.As()->empty()); @@ -1692,9 +1681,9 @@ TEST_P(ProtoStructValueTest, DoubleWrapperListGetField) { ASSERT_OK_AND_ASSIGN(auto value_without, ProtoValue::Create(value_factory, CreateTestMessage())); ASSERT_OK_AND_ASSIGN( - auto field, value_without->GetField( - StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("repeated_double_wrapper"))); + auto field, + value_without->GetFieldByName(StructValue::GetFieldContext(value_factory), + "repeated_double_wrapper")); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 0); EXPECT_TRUE(field.As()->empty()); @@ -1706,10 +1695,9 @@ TEST_P(ProtoStructValueTest, DoubleWrapperListGetField) { message.add_repeated_double_wrapper()->set_value(1.0); message.add_repeated_double_wrapper()->set_value(0.0); }))); - ASSERT_OK_AND_ASSIGN( - field, value_with->GetField( - StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("repeated_double_wrapper"))); + ASSERT_OK_AND_ASSIGN(field, value_with->GetFieldByName( + StructValue::GetFieldContext(value_factory), + "repeated_double_wrapper")); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 2); EXPECT_FALSE(field.As()->empty()); @@ -1731,10 +1719,10 @@ TEST_P(ProtoStructValueTest, BytesWrapperListGetField) { ValueFactory value_factory(type_manager); ASSERT_OK_AND_ASSIGN(auto value_without, ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN(auto field, - value_without->GetField( - StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("repeated_bytes_wrapper"))); + ASSERT_OK_AND_ASSIGN( + auto field, + value_without->GetFieldByName(StructValue::GetFieldContext(value_factory), + "repeated_bytes_wrapper")); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 0); EXPECT_TRUE(field.As()->empty()); @@ -1746,10 +1734,9 @@ TEST_P(ProtoStructValueTest, BytesWrapperListGetField) { message.add_repeated_bytes_wrapper()->set_value("foo"); message.add_repeated_bytes_wrapper()->set_value("bar"); }))); - ASSERT_OK_AND_ASSIGN( - field, - value_with->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("repeated_bytes_wrapper"))); + ASSERT_OK_AND_ASSIGN(field, value_with->GetFieldByName( + StructValue::GetFieldContext(value_factory), + "repeated_bytes_wrapper")); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 2); EXPECT_FALSE(field.As()->empty()); @@ -1772,9 +1759,9 @@ TEST_P(ProtoStructValueTest, StringWrapperListGetField) { ASSERT_OK_AND_ASSIGN(auto value_without, ProtoValue::Create(value_factory, CreateTestMessage())); ASSERT_OK_AND_ASSIGN( - auto field, value_without->GetField( - StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("repeated_string_wrapper"))); + auto field, + value_without->GetFieldByName(StructValue::GetFieldContext(value_factory), + "repeated_string_wrapper")); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 0); EXPECT_TRUE(field.As()->empty()); @@ -1786,10 +1773,9 @@ TEST_P(ProtoStructValueTest, StringWrapperListGetField) { message.add_repeated_string_wrapper()->set_value("foo"); message.add_repeated_string_wrapper()->set_value("bar"); }))); - ASSERT_OK_AND_ASSIGN( - field, value_with->GetField( - StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId("repeated_string_wrapper"))); + ASSERT_OK_AND_ASSIGN(field, value_with->GetFieldByName( + StructValue::GetFieldContext(value_factory), + "repeated_string_wrapper")); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 2); EXPECT_FALSE(field.As()->empty()); @@ -1814,10 +1800,9 @@ void TestMapHasField(MemoryManager& memory_manager, ValueFactory value_factory(type_manager); ASSERT_OK_AND_ASSIGN(auto value_without, ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT( - value_without->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId(map_field_name)), - IsOkAndHolds(Eq(false))); + EXPECT_THAT(value_without->HasFieldByName( + StructValue::HasFieldContext(type_manager), map_field_name), + IsOkAndHolds(Eq(false))); ASSERT_OK_AND_ASSIGN( auto value_with, ProtoValue::Create( @@ -1826,8 +1811,8 @@ void TestMapHasField(MemoryManager& memory_manager, TestAllTypes& message) mutable { (message.*mutable_map_field)()->insert(std::forward(pair)); }))); - EXPECT_THAT(value_with->HasField(StructValue::HasFieldContext(type_manager), - ProtoStructType::FieldId(map_field_name)), + EXPECT_THAT(value_with->HasFieldByName( + StructValue::HasFieldContext(type_manager), map_field_name), IsOkAndHolds(Eq(true))); } @@ -1881,8 +1866,8 @@ void TestMapGetField(MemoryManager& memory_manager, ProtoValue::Create(value_factory, CreateTestMessage())); ASSERT_OK_AND_ASSIGN( auto field, - value_without->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId(map_field_name))); + value_without->GetFieldByName(StructValue::GetFieldContext(value_factory), + map_field_name)); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 0); EXPECT_TRUE(field.As()->empty()); @@ -1896,8 +1881,8 @@ void TestMapGetField(MemoryManager& memory_manager, (message.*mutable_map_field)()->insert(pair2); }))); ASSERT_OK_AND_ASSIGN( - field, value_with->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId(map_field_name))); + field, value_with->GetFieldByName( + StructValue::GetFieldContext(value_factory), map_field_name)); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 2); EXPECT_FALSE(field.As()->empty()); @@ -1975,8 +1960,8 @@ void TestStringMapGetField(MemoryManager& memory_manager, ProtoValue::Create(value_factory, CreateTestMessage())); ASSERT_OK_AND_ASSIGN( auto field, - value_without->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId(map_field_name))); + value_without->GetFieldByName(StructValue::GetFieldContext(value_factory), + map_field_name)); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 0); EXPECT_TRUE(field.As()->empty()); @@ -1990,8 +1975,8 @@ void TestStringMapGetField(MemoryManager& memory_manager, (message.*mutable_map_field)()->insert(pair2); }))); ASSERT_OK_AND_ASSIGN( - field, value_with->GetField(StructValue::GetFieldContext(value_factory), - ProtoStructType::FieldId(map_field_name))); + field, value_with->GetFieldByName( + StructValue::GetFieldContext(value_factory), map_field_name)); EXPECT_TRUE(field->Is()); EXPECT_EQ(field.As()->size(), 2); EXPECT_FALSE(field.As()->empty()); @@ -3437,6 +3422,8 @@ TEST_P(ProtoStructValueTest, DynamicRValueDifferentDescriptors) { EXPECT_TRUE(value->Is()); } +using ::cel::base_internal::FieldIdFactory; + TEST_P(ProtoStructValueTest, NewFieldIteratorIds) { TypeFactory type_factory(memory_manager()); ProtoTypeProvider type_provider; @@ -3472,19 +3459,13 @@ TEST_P(ProtoStructValueTest, NewFieldIteratorIds) { EXPECT_THAT(iterator->NextId(StructValue::GetFieldContext(value_factory)), CanonicalStatusIs(absl::StatusCode::kFailedPrecondition)); std::set expected_ids = { - StructValue::FieldId("single_bool"), - StructValue::FieldId("single_int32"), - StructValue::FieldId("single_int64"), - StructValue::FieldId("single_uint32"), - StructValue::FieldId("single_uint64"), - StructValue::FieldId("single_float"), - StructValue::FieldId("single_double"), - StructValue::FieldId("single_bytes"), - StructValue::FieldId("single_string"), - StructValue::FieldId("standalone_enum"), - StructValue::FieldId("standalone_message"), - StructValue::FieldId("single_duration"), - StructValue::FieldId("single_timestamp")}; + FieldIdFactory::Make(13), FieldIdFactory::Make(1), + FieldIdFactory::Make(2), FieldIdFactory::Make(3), + FieldIdFactory::Make(4), FieldIdFactory::Make(11), + FieldIdFactory::Make(12), FieldIdFactory::Make(15), + FieldIdFactory::Make(14), FieldIdFactory::Make(24), + FieldIdFactory::Make(23), FieldIdFactory::Make(101), + FieldIdFactory::Make(102)}; EXPECT_EQ(actual_ids, expected_ids); } From 1c1cfd0e5d23476f01db3cac813f4ffca2e514dd Mon Sep 17 00:00:00 2001 From: jcking Date: Thu, 11 May 2023 18:56:31 +0000 Subject: [PATCH 261/303] Yet more test simplification PiperOrigin-RevId: 531267255 --- extensions/protobuf/struct_value_test.cc | 1291 ++++++++-------------- 1 file changed, 458 insertions(+), 833 deletions(-) diff --git a/extensions/protobuf/struct_value_test.cc b/extensions/protobuf/struct_value_test.cc index c7cdff05b..a624c4a77 100644 --- a/extensions/protobuf/struct_value_test.cc +++ b/extensions/protobuf/struct_value_test.cc @@ -880,914 +880,539 @@ TEST_P(ProtoStructValueTest, ValueGetField) { }); } -TEST_P(ProtoStructValueTest, NullValueListGetField) { - TypeFactory type_factory(memory_manager()); +void TestGetListField( + MemoryManager& memory_manager, absl::string_view name, + absl::FunctionRef&)> unset_field_tester, + absl::FunctionRef test_message_maker, + absl::FunctionRef&)> + set_field_tester) { + TypeFactory type_factory(memory_manager); ProtoTypeProvider type_provider; TypeManager type_manager(type_factory, type_provider); ValueFactory value_factory(type_manager); ASSERT_OK_AND_ASSIGN(auto value_without, ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN( - auto field, - value_without->GetFieldByName(StructValue::GetFieldContext(value_factory), - "repeated_null_value")); - EXPECT_TRUE(field->Is()); - EXPECT_EQ(field.As()->size(), 0); - EXPECT_TRUE(field.As()->empty()); - EXPECT_EQ(field->DebugString(), "[]"); + ASSERT_OK_AND_ASSIGN(auto field, + value_without->GetFieldByName( + StructValue::GetFieldContext(value_factory), name)); + ASSERT_TRUE(field->Is()); + ASSERT_NO_FATAL_FAILURE(unset_field_tester(field.As())); ASSERT_OK_AND_ASSIGN( auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.add_repeated_null_value(NULL_VALUE); - message.add_repeated_null_value(NULL_VALUE); - }))); - ASSERT_OK_AND_ASSIGN(field, value_with->GetFieldByName( - StructValue::GetFieldContext(value_factory), - "repeated_null_value")); - EXPECT_TRUE(field->Is()); - EXPECT_EQ(field.As()->size(), 2); - EXPECT_FALSE(field.As()->empty()); - EXPECT_EQ(field->DebugString(), "[null, null]"); - ASSERT_OK_AND_ASSIGN( - auto field_value, - field.As()->Get(ListValue::GetContext(value_factory), 0)); - EXPECT_TRUE(field_value->Is()); - ASSERT_OK_AND_ASSIGN( - field_value, - field.As()->Get(ListValue::GetContext(value_factory), 1)); - EXPECT_TRUE(field_value->Is()); + ProtoValue::Create(value_factory, CreateTestMessage(test_message_maker))); + ASSERT_OK_AND_ASSIGN(field, + value_with->GetFieldByName( + StructValue::GetFieldContext(value_factory), name)); + ASSERT_TRUE(field->Is()); + ASSERT_NO_FATAL_FAILURE( + set_field_tester(value_factory, field.As())); } -TEST_P(ProtoStructValueTest, BoolListGetField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN( - auto field, - value_without->GetFieldByName(StructValue::GetFieldContext(value_factory), - "repeated_bool")); - EXPECT_TRUE(field->Is()); - EXPECT_EQ(field.As()->size(), 0); - EXPECT_TRUE(field.As()->empty()); +#define TEST_GET_LIST_FIELD(...) \ + ASSERT_NO_FATAL_FAILURE(TestGetListField(__VA_ARGS__)) + +void EmptyListFieldTester(const Handle& field) { + EXPECT_EQ(field->size(), 0); + EXPECT_TRUE(field->empty()); EXPECT_EQ(field->DebugString(), "[]"); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.add_repeated_bool(true); - message.add_repeated_bool(false); - }))); - ASSERT_OK_AND_ASSIGN( - field, value_with->GetFieldByName( - StructValue::GetFieldContext(value_factory), "repeated_bool")); - EXPECT_TRUE(field->Is()); - EXPECT_EQ(field.As()->size(), 2); - EXPECT_FALSE(field.As()->empty()); - EXPECT_EQ(field->DebugString(), "[true, false]"); - ASSERT_OK_AND_ASSIGN( - auto field_value, - field.As()->Get(ListValue::GetContext(value_factory), 0)); - EXPECT_TRUE(field_value.As()->value()); - ASSERT_OK_AND_ASSIGN( - field_value, - field.As()->Get(ListValue::GetContext(value_factory), 1)); - EXPECT_FALSE(field_value.As()->value()); +} + +TEST_P(ProtoStructValueTest, NullValueListGetField) { + TEST_GET_LIST_FIELD( + memory_manager(), "repeated_null_value", EmptyListFieldTester, + [](TestAllTypes& message) { + message.add_repeated_null_value(NULL_VALUE); + message.add_repeated_null_value(NULL_VALUE); + }, + [](ValueFactory& value_factory, const Handle& field) { + EXPECT_EQ(field->size(), 2); + EXPECT_FALSE(field->empty()); + EXPECT_EQ(field->DebugString(), "[null, null]"); + ASSERT_OK_AND_ASSIGN( + auto field_value, + field->Get(ListValue::GetContext(value_factory), 0)); + EXPECT_TRUE(field_value->Is()); + ASSERT_OK_AND_ASSIGN( + field_value, field->Get(ListValue::GetContext(value_factory), 1)); + EXPECT_TRUE(field_value->Is()); + }); +} + +TEST_P(ProtoStructValueTest, BoolListGetField) { + TEST_GET_LIST_FIELD( + memory_manager(), "repeated_bool", EmptyListFieldTester, + [](TestAllTypes& message) { + message.add_repeated_bool(true); + message.add_repeated_bool(false); + }, + [](ValueFactory& value_factory, const Handle& field) { + EXPECT_EQ(field->size(), 2); + EXPECT_FALSE(field->empty()); + EXPECT_EQ(field->DebugString(), "[true, false]"); + ASSERT_OK_AND_ASSIGN( + auto field_value, + field->Get(ListValue::GetContext(value_factory), 0)); + EXPECT_TRUE(field_value.As()->value()); + ASSERT_OK_AND_ASSIGN( + field_value, field->Get(ListValue::GetContext(value_factory), 1)); + EXPECT_FALSE(field_value.As()->value()); + }); } TEST_P(ProtoStructValueTest, Int32ListGetField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN( - auto field, - value_without->GetFieldByName(StructValue::GetFieldContext(value_factory), - "repeated_int32")); - EXPECT_TRUE(field->Is()); - EXPECT_EQ(field.As()->size(), 0); - EXPECT_TRUE(field.As()->empty()); - EXPECT_EQ(field->DebugString(), "[]"); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.add_repeated_int32(1); - message.add_repeated_int32(0); - }))); - ASSERT_OK_AND_ASSIGN(field, value_with->GetFieldByName( - StructValue::GetFieldContext(value_factory), - "repeated_int32")); - EXPECT_TRUE(field->Is()); - EXPECT_EQ(field.As()->size(), 2); - EXPECT_FALSE(field.As()->empty()); - EXPECT_EQ(field->DebugString(), "[1, 0]"); - ASSERT_OK_AND_ASSIGN( - auto field_value, - field.As()->Get(ListValue::GetContext(value_factory), 0)); - EXPECT_EQ(field_value.As()->value(), 1); - ASSERT_OK_AND_ASSIGN( - field_value, - field.As()->Get(ListValue::GetContext(value_factory), 1)); - EXPECT_EQ(field_value.As()->value(), 0); + TEST_GET_LIST_FIELD( + memory_manager(), "repeated_int32", EmptyListFieldTester, + [](TestAllTypes& message) { + message.add_repeated_int32(1); + message.add_repeated_int32(0); + }, + [](ValueFactory& value_factory, const Handle& field) { + EXPECT_EQ(field->size(), 2); + EXPECT_FALSE(field->empty()); + EXPECT_EQ(field->DebugString(), "[1, 0]"); + ASSERT_OK_AND_ASSIGN( + auto field_value, + field->Get(ListValue::GetContext(value_factory), 0)); + EXPECT_EQ(field_value.As()->value(), 1); + ASSERT_OK_AND_ASSIGN( + field_value, field->Get(ListValue::GetContext(value_factory), 1)); + EXPECT_EQ(field_value.As()->value(), 0); + }); } TEST_P(ProtoStructValueTest, Int64ListGetField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN( - auto field, - value_without->GetFieldByName(StructValue::GetFieldContext(value_factory), - "repeated_int64")); - EXPECT_TRUE(field->Is()); - EXPECT_EQ(field.As()->size(), 0); - EXPECT_TRUE(field.As()->empty()); - EXPECT_EQ(field->DebugString(), "[]"); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.add_repeated_int64(1); - message.add_repeated_int64(0); - }))); - ASSERT_OK_AND_ASSIGN(field, value_with->GetFieldByName( - StructValue::GetFieldContext(value_factory), - "repeated_int64")); - EXPECT_TRUE(field->Is()); - EXPECT_EQ(field.As()->size(), 2); - EXPECT_FALSE(field.As()->empty()); - EXPECT_EQ(field->DebugString(), "[1, 0]"); - ASSERT_OK_AND_ASSIGN( - auto field_value, - field.As()->Get(ListValue::GetContext(value_factory), 0)); - EXPECT_EQ(field_value.As()->value(), 1); - ASSERT_OK_AND_ASSIGN( - field_value, - field.As()->Get(ListValue::GetContext(value_factory), 1)); - EXPECT_EQ(field_value.As()->value(), 0); + TEST_GET_LIST_FIELD( + memory_manager(), "repeated_int64", EmptyListFieldTester, + [](TestAllTypes& message) { + message.add_repeated_int64(1); + message.add_repeated_int64(0); + }, + [](ValueFactory& value_factory, const Handle& field) { + EXPECT_EQ(field->size(), 2); + EXPECT_FALSE(field->empty()); + EXPECT_EQ(field->DebugString(), "[1, 0]"); + ASSERT_OK_AND_ASSIGN( + auto field_value, + field->Get(ListValue::GetContext(value_factory), 0)); + EXPECT_EQ(field_value.As()->value(), 1); + ASSERT_OK_AND_ASSIGN( + field_value, field->Get(ListValue::GetContext(value_factory), 1)); + EXPECT_EQ(field_value.As()->value(), 0); + }); } TEST_P(ProtoStructValueTest, Uint32ListGetField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN( - auto field, - value_without->GetFieldByName(StructValue::GetFieldContext(value_factory), - "repeated_uint32")); - EXPECT_TRUE(field->Is()); - EXPECT_EQ(field.As()->size(), 0); - EXPECT_TRUE(field.As()->empty()); - EXPECT_EQ(field->DebugString(), "[]"); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.add_repeated_uint32(1); - message.add_repeated_uint32(0); - }))); - ASSERT_OK_AND_ASSIGN(field, value_with->GetFieldByName( - StructValue::GetFieldContext(value_factory), - "repeated_uint32")); - EXPECT_TRUE(field->Is()); - EXPECT_EQ(field.As()->size(), 2); - EXPECT_FALSE(field.As()->empty()); - EXPECT_EQ(field->DebugString(), "[1u, 0u]"); - ASSERT_OK_AND_ASSIGN( - auto field_value, - field.As()->Get(ListValue::GetContext(value_factory), 0)); - EXPECT_EQ(field_value.As()->value(), 1); - ASSERT_OK_AND_ASSIGN( - field_value, - field.As()->Get(ListValue::GetContext(value_factory), 1)); - EXPECT_EQ(field_value.As()->value(), 0); + TEST_GET_LIST_FIELD( + memory_manager(), "repeated_uint32", EmptyListFieldTester, + [](TestAllTypes& message) { + message.add_repeated_uint32(1); + message.add_repeated_uint32(0); + }, + [](ValueFactory& value_factory, const Handle& field) { + EXPECT_EQ(field->size(), 2); + EXPECT_FALSE(field->empty()); + EXPECT_EQ(field->DebugString(), "[1u, 0u]"); + ASSERT_OK_AND_ASSIGN( + auto field_value, + field->Get(ListValue::GetContext(value_factory), 0)); + EXPECT_EQ(field_value.As()->value(), 1); + ASSERT_OK_AND_ASSIGN( + field_value, field->Get(ListValue::GetContext(value_factory), 1)); + EXPECT_EQ(field_value.As()->value(), 0); + }); } TEST_P(ProtoStructValueTest, Uint64ListGetField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN( - auto field, - value_without->GetFieldByName(StructValue::GetFieldContext(value_factory), - "repeated_uint64")); - EXPECT_TRUE(field->Is()); - EXPECT_EQ(field.As()->size(), 0); - EXPECT_TRUE(field.As()->empty()); - EXPECT_EQ(field->DebugString(), "[]"); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.add_repeated_uint64(1); - message.add_repeated_uint64(0); - }))); - ASSERT_OK_AND_ASSIGN(field, value_with->GetFieldByName( - StructValue::GetFieldContext(value_factory), - "repeated_uint64")); - EXPECT_TRUE(field->Is()); - EXPECT_EQ(field.As()->size(), 2); - EXPECT_FALSE(field.As()->empty()); - EXPECT_EQ(field->DebugString(), "[1u, 0u]"); - ASSERT_OK_AND_ASSIGN( - auto field_value, - field.As()->Get(ListValue::GetContext(value_factory), 0)); - EXPECT_EQ(field_value.As()->value(), 1); - ASSERT_OK_AND_ASSIGN( - field_value, - field.As()->Get(ListValue::GetContext(value_factory), 1)); - EXPECT_EQ(field_value.As()->value(), 0); + TEST_GET_LIST_FIELD( + memory_manager(), "repeated_uint64", EmptyListFieldTester, + [](TestAllTypes& message) { + message.add_repeated_uint64(1); + message.add_repeated_uint64(0); + }, + [](ValueFactory& value_factory, const Handle& field) { + EXPECT_EQ(field->size(), 2); + EXPECT_FALSE(field->empty()); + EXPECT_EQ(field->DebugString(), "[1u, 0u]"); + ASSERT_OK_AND_ASSIGN( + auto field_value, + field->Get(ListValue::GetContext(value_factory), 0)); + EXPECT_EQ(field_value.As()->value(), 1); + ASSERT_OK_AND_ASSIGN( + field_value, field->Get(ListValue::GetContext(value_factory), 1)); + EXPECT_EQ(field_value.As()->value(), 0); + }); } TEST_P(ProtoStructValueTest, FloatListGetField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN( - auto field, - value_without->GetFieldByName(StructValue::GetFieldContext(value_factory), - "repeated_float")); - EXPECT_TRUE(field->Is()); - EXPECT_EQ(field.As()->size(), 0); - EXPECT_TRUE(field.As()->empty()); - EXPECT_EQ(field->DebugString(), "[]"); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.add_repeated_float(1.0); - message.add_repeated_float(0.0); - }))); - ASSERT_OK_AND_ASSIGN(field, value_with->GetFieldByName( - StructValue::GetFieldContext(value_factory), - "repeated_float")); - EXPECT_TRUE(field->Is()); - EXPECT_EQ(field.As()->size(), 2); - EXPECT_FALSE(field.As()->empty()); - EXPECT_EQ(field->DebugString(), "[1.0, 0.0]"); - ASSERT_OK_AND_ASSIGN( - auto field_value, - field.As()->Get(ListValue::GetContext(value_factory), 0)); - EXPECT_EQ(field_value.As()->value(), 1.0); - ASSERT_OK_AND_ASSIGN( - field_value, - field.As()->Get(ListValue::GetContext(value_factory), 1)); - EXPECT_EQ(field_value.As()->value(), 0.0); + TEST_GET_LIST_FIELD( + memory_manager(), "repeated_float", EmptyListFieldTester, + [](TestAllTypes& message) { + message.add_repeated_float(1.0); + message.add_repeated_float(0.0); + }, + [](ValueFactory& value_factory, const Handle& field) { + EXPECT_EQ(field->size(), 2); + EXPECT_FALSE(field->empty()); + EXPECT_EQ(field->DebugString(), "[1.0, 0.0]"); + ASSERT_OK_AND_ASSIGN( + auto field_value, + field->Get(ListValue::GetContext(value_factory), 0)); + EXPECT_EQ(field_value.As()->value(), 1.0); + ASSERT_OK_AND_ASSIGN( + field_value, field->Get(ListValue::GetContext(value_factory), 1)); + EXPECT_EQ(field_value.As()->value(), 0.0); + }); } TEST_P(ProtoStructValueTest, DoubleListGetField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN( - auto field, - value_without->GetFieldByName(StructValue::GetFieldContext(value_factory), - "repeated_double")); - EXPECT_TRUE(field->Is()); - EXPECT_EQ(field.As()->size(), 0); - EXPECT_TRUE(field.As()->empty()); - EXPECT_EQ(field->DebugString(), "[]"); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.add_repeated_double(1.0); - message.add_repeated_double(0.0); - }))); - ASSERT_OK_AND_ASSIGN(field, value_with->GetFieldByName( - StructValue::GetFieldContext(value_factory), - "repeated_double")); - EXPECT_TRUE(field->Is()); - EXPECT_EQ(field.As()->size(), 2); - EXPECT_FALSE(field.As()->empty()); - EXPECT_EQ(field->DebugString(), "[1.0, 0.0]"); - ASSERT_OK_AND_ASSIGN( - auto field_value, - field.As()->Get(ListValue::GetContext(value_factory), 0)); - EXPECT_EQ(field_value.As()->value(), 1.0); - ASSERT_OK_AND_ASSIGN( - field_value, - field.As()->Get(ListValue::GetContext(value_factory), 1)); - EXPECT_EQ(field_value.As()->value(), 0.0); + TEST_GET_LIST_FIELD( + memory_manager(), "repeated_double", EmptyListFieldTester, + [](TestAllTypes& message) { + message.add_repeated_double(1.0); + message.add_repeated_double(0.0); + }, + [](ValueFactory& value_factory, const Handle& field) { + EXPECT_EQ(field->size(), 2); + EXPECT_FALSE(field->empty()); + EXPECT_EQ(field->DebugString(), "[1.0, 0.0]"); + ASSERT_OK_AND_ASSIGN( + auto field_value, + field->Get(ListValue::GetContext(value_factory), 0)); + EXPECT_EQ(field_value.As()->value(), 1.0); + ASSERT_OK_AND_ASSIGN( + field_value, field->Get(ListValue::GetContext(value_factory), 1)); + EXPECT_EQ(field_value.As()->value(), 0.0); + }); } TEST_P(ProtoStructValueTest, BytesListGetField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN( - auto field, - value_without->GetFieldByName(StructValue::GetFieldContext(value_factory), - "repeated_bytes")); - EXPECT_TRUE(field->Is()); - EXPECT_EQ(field.As()->size(), 0); - EXPECT_TRUE(field.As()->empty()); - EXPECT_EQ(field->DebugString(), "[]"); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.add_repeated_bytes("foo"); - message.add_repeated_bytes("bar"); - }))); - ASSERT_OK_AND_ASSIGN(field, value_with->GetFieldByName( - StructValue::GetFieldContext(value_factory), - "repeated_bytes")); - EXPECT_TRUE(field->Is()); - EXPECT_EQ(field.As()->size(), 2); - EXPECT_FALSE(field.As()->empty()); - EXPECT_EQ(field->DebugString(), "[b\"foo\", b\"bar\"]"); - ASSERT_OK_AND_ASSIGN( - auto field_value, - field.As()->Get(ListValue::GetContext(value_factory), 0)); - EXPECT_EQ(field_value.As()->ToString(), "foo"); - ASSERT_OK_AND_ASSIGN( - field_value, - field.As()->Get(ListValue::GetContext(value_factory), 1)); - EXPECT_EQ(field_value.As()->ToString(), "bar"); + TEST_GET_LIST_FIELD( + memory_manager(), "repeated_bytes", EmptyListFieldTester, + [](TestAllTypes& message) { + message.add_repeated_bytes("foo"); + message.add_repeated_bytes("bar"); + }, + [](ValueFactory& value_factory, const Handle& field) { + EXPECT_EQ(field->size(), 2); + EXPECT_FALSE(field->empty()); + EXPECT_EQ(field->DebugString(), "[b\"foo\", b\"bar\"]"); + ASSERT_OK_AND_ASSIGN( + auto field_value, + field->Get(ListValue::GetContext(value_factory), 0)); + EXPECT_EQ(field_value.As()->ToString(), "foo"); + ASSERT_OK_AND_ASSIGN( + field_value, field->Get(ListValue::GetContext(value_factory), 1)); + EXPECT_EQ(field_value.As()->ToString(), "bar"); + }); } TEST_P(ProtoStructValueTest, StringListGetField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN( - auto field, - value_without->GetFieldByName(StructValue::GetFieldContext(value_factory), - "repeated_string")); - EXPECT_TRUE(field->Is()); - EXPECT_EQ(field.As()->size(), 0); - EXPECT_TRUE(field.As()->empty()); - EXPECT_EQ(field->DebugString(), "[]"); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.add_repeated_string("foo"); - message.add_repeated_string("bar"); - }))); - ASSERT_OK_AND_ASSIGN(field, value_with->GetFieldByName( - StructValue::GetFieldContext(value_factory), - "repeated_string")); - EXPECT_TRUE(field->Is()); - EXPECT_EQ(field.As()->size(), 2); - EXPECT_FALSE(field.As()->empty()); - EXPECT_EQ(field->DebugString(), "[\"foo\", \"bar\"]"); - ASSERT_OK_AND_ASSIGN( - auto field_value, - field.As()->Get(ListValue::GetContext(value_factory), 0)); - EXPECT_EQ(field_value.As()->ToString(), "foo"); - ASSERT_OK_AND_ASSIGN( - field_value, - field.As()->Get(ListValue::GetContext(value_factory), 1)); - EXPECT_EQ(field_value.As()->ToString(), "bar"); + TEST_GET_LIST_FIELD( + memory_manager(), "repeated_string", EmptyListFieldTester, + [](TestAllTypes& message) { + message.add_repeated_string("foo"); + message.add_repeated_string("bar"); + }, + [](ValueFactory& value_factory, const Handle& field) { + EXPECT_EQ(field->size(), 2); + EXPECT_FALSE(field->empty()); + EXPECT_EQ(field->DebugString(), "[\"foo\", \"bar\"]"); + ASSERT_OK_AND_ASSIGN( + auto field_value, + field->Get(ListValue::GetContext(value_factory), 0)); + EXPECT_EQ(field_value.As()->ToString(), "foo"); + ASSERT_OK_AND_ASSIGN( + field_value, field->Get(ListValue::GetContext(value_factory), 1)); + EXPECT_EQ(field_value.As()->ToString(), "bar"); + }); } TEST_P(ProtoStructValueTest, DurationListGetField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN( - auto field, - value_without->GetFieldByName(StructValue::GetFieldContext(value_factory), - "repeated_duration")); - EXPECT_TRUE(field->Is()); - EXPECT_EQ(field.As()->size(), 0); - EXPECT_TRUE(field.As()->empty()); - EXPECT_EQ(field->DebugString(), "[]"); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.add_repeated_duration()->set_seconds(1); - message.add_repeated_duration()->set_seconds(2); - }))); - ASSERT_OK_AND_ASSIGN(field, value_with->GetFieldByName( - StructValue::GetFieldContext(value_factory), - "repeated_duration")); - EXPECT_TRUE(field->Is()); - EXPECT_EQ(field.As()->size(), 2); - EXPECT_FALSE(field.As()->empty()); - EXPECT_EQ(field->DebugString(), "[1s, 2s]"); - ASSERT_OK_AND_ASSIGN( - auto field_value, - field.As()->Get(ListValue::GetContext(value_factory), 0)); - EXPECT_EQ(field_value.As()->value(), absl::Seconds(1)); - ASSERT_OK_AND_ASSIGN( - field_value, - field.As()->Get(ListValue::GetContext(value_factory), 1)); - EXPECT_EQ(field_value.As()->value(), absl::Seconds(2)); + TEST_GET_LIST_FIELD( + memory_manager(), "repeated_duration", EmptyListFieldTester, + [](TestAllTypes& message) { + message.add_repeated_duration()->set_seconds(1); + message.add_repeated_duration()->set_seconds(2); + }, + [](ValueFactory& value_factory, const Handle& field) { + EXPECT_EQ(field->size(), 2); + EXPECT_FALSE(field->empty()); + EXPECT_EQ(field->DebugString(), "[1s, 2s]"); + ASSERT_OK_AND_ASSIGN( + auto field_value, + field->Get(ListValue::GetContext(value_factory), 0)); + EXPECT_EQ(field_value.As()->value(), absl::Seconds(1)); + ASSERT_OK_AND_ASSIGN( + field_value, field->Get(ListValue::GetContext(value_factory), 1)); + EXPECT_EQ(field_value.As()->value(), absl::Seconds(2)); + }); } TEST_P(ProtoStructValueTest, TimestampListGetField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN( - auto field, - value_without->GetFieldByName(StructValue::GetFieldContext(value_factory), - "repeated_timestamp")); - EXPECT_TRUE(field->Is()); - EXPECT_EQ(field.As()->size(), 0); - EXPECT_TRUE(field.As()->empty()); - EXPECT_EQ(field->DebugString(), "[]"); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.add_repeated_timestamp()->set_seconds(1); - message.add_repeated_timestamp()->set_seconds(2); - }))); - ASSERT_OK_AND_ASSIGN(field, value_with->GetFieldByName( - StructValue::GetFieldContext(value_factory), - "repeated_timestamp")); - EXPECT_TRUE(field->Is()); - EXPECT_EQ(field.As()->size(), 2); - EXPECT_FALSE(field.As()->empty()); - EXPECT_EQ(field->DebugString(), - "[1970-01-01T00:00:01Z, 1970-01-01T00:00:02Z]"); - ASSERT_OK_AND_ASSIGN( - auto field_value, - field.As()->Get(ListValue::GetContext(value_factory), 0)); - EXPECT_EQ(field_value.As()->value(), - absl::UnixEpoch() + absl::Seconds(1)); - ASSERT_OK_AND_ASSIGN( - field_value, - field.As()->Get(ListValue::GetContext(value_factory), 1)); - EXPECT_EQ(field_value.As()->value(), - absl::UnixEpoch() + absl::Seconds(2)); + TEST_GET_LIST_FIELD( + memory_manager(), "repeated_timestamp", EmptyListFieldTester, + [](TestAllTypes& message) { + message.add_repeated_timestamp()->set_seconds(1); + message.add_repeated_timestamp()->set_seconds(2); + }, + [](ValueFactory& value_factory, const Handle& field) { + EXPECT_EQ(field->size(), 2); + EXPECT_FALSE(field->empty()); + EXPECT_EQ(field->DebugString(), + "[1970-01-01T00:00:01Z, 1970-01-01T00:00:02Z]"); + ASSERT_OK_AND_ASSIGN( + auto field_value, + field->Get(ListValue::GetContext(value_factory), 0)); + EXPECT_EQ(field_value.As()->value(), + absl::UnixEpoch() + absl::Seconds(1)); + ASSERT_OK_AND_ASSIGN( + field_value, field->Get(ListValue::GetContext(value_factory), 1)); + EXPECT_EQ(field_value.As()->value(), + absl::UnixEpoch() + absl::Seconds(2)); + }); } TEST_P(ProtoStructValueTest, EnumListGetField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN( - auto field, - value_without->GetFieldByName(StructValue::GetFieldContext(value_factory), - "repeated_nested_enum")); - EXPECT_TRUE(field->Is()); - EXPECT_EQ(field.As()->size(), 0); - EXPECT_TRUE(field.As()->empty()); - EXPECT_EQ(field->DebugString(), "[]"); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.add_repeated_nested_enum(TestAllTypes::FOO); - message.add_repeated_nested_enum(TestAllTypes::BAR); - }))); - ASSERT_OK_AND_ASSIGN(field, value_with->GetFieldByName( - StructValue::GetFieldContext(value_factory), - "repeated_nested_enum")); - EXPECT_TRUE(field->Is()); - EXPECT_EQ(field.As()->size(), 2); - EXPECT_FALSE(field.As()->empty()); - EXPECT_EQ(field->DebugString(), + TEST_GET_LIST_FIELD( + memory_manager(), "repeated_nested_enum", EmptyListFieldTester, + [](TestAllTypes& message) { + message.add_repeated_nested_enum(TestAllTypes::FOO); + message.add_repeated_nested_enum(TestAllTypes::BAR); + }, + [](ValueFactory& value_factory, const Handle& field) { + EXPECT_EQ(field->size(), 2); + EXPECT_FALSE(field->empty()); + EXPECT_EQ( + field->DebugString(), "[google.api.expr.test.v1.proto3.TestAllTypes.NestedEnum.FOO, " "google.api.expr.test.v1.proto3.TestAllTypes.NestedEnum.BAR]"); - ASSERT_OK_AND_ASSIGN( - auto field_value, - field.As()->Get(ListValue::GetContext(value_factory), 0)); - EXPECT_EQ(field_value.As()->number(), TestAllTypes::FOO); - ASSERT_OK_AND_ASSIGN( - field_value, - field.As()->Get(ListValue::GetContext(value_factory), 1)); - EXPECT_EQ(field_value.As()->number(), TestAllTypes::BAR); + ASSERT_OK_AND_ASSIGN( + auto field_value, + field->Get(ListValue::GetContext(value_factory), 0)); + EXPECT_EQ(field_value.As()->number(), TestAllTypes::FOO); + ASSERT_OK_AND_ASSIGN( + field_value, field->Get(ListValue::GetContext(value_factory), 1)); + EXPECT_EQ(field_value.As()->number(), TestAllTypes::BAR); + }); } TEST_P(ProtoStructValueTest, StructListGetField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN( - auto field, - value_without->GetFieldByName(StructValue::GetFieldContext(value_factory), - "repeated_nested_message")); - EXPECT_TRUE(field->Is()); - EXPECT_EQ(field.As()->size(), 0); - EXPECT_TRUE(field.As()->empty()); - EXPECT_EQ(field->DebugString(), "[]"); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.add_repeated_nested_message()->set_bb(1); - message.add_repeated_nested_message()->set_bb(2); - }))); - ASSERT_OK_AND_ASSIGN(field, value_with->GetFieldByName( - StructValue::GetFieldContext(value_factory), - "repeated_nested_message")); - EXPECT_TRUE(field->Is()); - EXPECT_EQ(field.As()->size(), 2); - EXPECT_FALSE(field.As()->empty()); - EXPECT_EQ( - field->DebugString(), - "[google.api.expr.test.v1.proto3.TestAllTypes.NestedMessage{bb: 1}, " - "google.api.expr.test.v1.proto3.TestAllTypes.NestedMessage{bb: 2}]"); - TestAllTypes::NestedMessage message; - ASSERT_OK_AND_ASSIGN( - auto field_value, - field.As()->Get(ListValue::GetContext(value_factory), 0)); - message.set_bb(1); - EXPECT_THAT(*field_value.As()->value(), - EqualsProto(message)); - ASSERT_OK_AND_ASSIGN( - field_value, - field.As()->Get(ListValue::GetContext(value_factory), 1)); - message.set_bb(2); - EXPECT_THAT(*field_value.As()->value(), - EqualsProto(message)); + TEST_GET_LIST_FIELD( + memory_manager(), "repeated_nested_message", EmptyListFieldTester, + [](TestAllTypes& message) { + message.add_repeated_nested_message()->set_bb(1); + message.add_repeated_nested_message()->set_bb(2); + }, + [](ValueFactory& value_factory, const Handle& field) { + EXPECT_EQ(field->size(), 2); + EXPECT_FALSE(field->empty()); + EXPECT_EQ(field->DebugString(), + "[google.api.expr.test.v1.proto3.TestAllTypes.NestedMessage{" + "bb: 1}, " + "google.api.expr.test.v1.proto3.TestAllTypes.NestedMessage{" + "bb: 2}]"); + TestAllTypes::NestedMessage message; + ASSERT_OK_AND_ASSIGN( + auto field_value, + field->Get(ListValue::GetContext(value_factory), 0)); + message.set_bb(1); + EXPECT_THAT(*field_value.As()->value(), + EqualsProto(message)); + ASSERT_OK_AND_ASSIGN( + field_value, field->Get(ListValue::GetContext(value_factory), 1)); + message.set_bb(2); + EXPECT_THAT(*field_value.As()->value(), + EqualsProto(message)); + }); } TEST_P(ProtoStructValueTest, BoolWrapperListGetField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN( - auto field, - value_without->GetFieldByName(StructValue::GetFieldContext(value_factory), - "repeated_bool_wrapper")); - EXPECT_TRUE(field->Is()); - EXPECT_EQ(field.As()->size(), 0); - EXPECT_TRUE(field.As()->empty()); - EXPECT_EQ(field->DebugString(), "[]"); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create( - value_factory, CreateTestMessage([](TestAllTypes& message) { - message.add_repeated_bool_wrapper()->set_value(true); - message.add_repeated_bool_wrapper()->set_value(false); - }))); - ASSERT_OK_AND_ASSIGN(field, value_with->GetFieldByName( - StructValue::GetFieldContext(value_factory), - "repeated_bool_wrapper")); - EXPECT_TRUE(field->Is()); - EXPECT_EQ(field.As()->size(), 2); - EXPECT_FALSE(field.As()->empty()); - EXPECT_EQ(field->DebugString(), "[true, false]"); - ASSERT_OK_AND_ASSIGN( - auto field_value, - field.As()->Get(ListValue::GetContext(value_factory), 0)); - EXPECT_TRUE(field_value.As()->value()); - ASSERT_OK_AND_ASSIGN( - field_value, - field.As()->Get(ListValue::GetContext(value_factory), 1)); - EXPECT_FALSE(field_value.As()->value()); + TEST_GET_LIST_FIELD( + memory_manager(), "repeated_bool_wrapper", EmptyListFieldTester, + [](TestAllTypes& message) { + message.add_repeated_bool_wrapper()->set_value(true); + message.add_repeated_bool_wrapper()->set_value(false); + }, + [](ValueFactory& value_factory, const Handle& field) { + EXPECT_EQ(field->size(), 2); + EXPECT_FALSE(field->empty()); + EXPECT_EQ(field->DebugString(), "[true, false]"); + ASSERT_OK_AND_ASSIGN( + auto field_value, + field->Get(ListValue::GetContext(value_factory), 0)); + EXPECT_TRUE(field_value.As()->value()); + ASSERT_OK_AND_ASSIGN( + field_value, field->Get(ListValue::GetContext(value_factory), 1)); + EXPECT_FALSE(field_value.As()->value()); + }); } TEST_P(ProtoStructValueTest, Int32WrapperListGetField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN( - auto field, - value_without->GetFieldByName(StructValue::GetFieldContext(value_factory), - "repeated_int32_wrapper")); - EXPECT_TRUE(field->Is()); - EXPECT_EQ(field.As()->size(), 0); - EXPECT_TRUE(field.As()->empty()); - EXPECT_EQ(field->DebugString(), "[]"); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.add_repeated_int32_wrapper()->set_value(1); - message.add_repeated_int32_wrapper()->set_value(0); - }))); - ASSERT_OK_AND_ASSIGN(field, value_with->GetFieldByName( - StructValue::GetFieldContext(value_factory), - "repeated_int32_wrapper")); - EXPECT_TRUE(field->Is()); - EXPECT_EQ(field.As()->size(), 2); - EXPECT_FALSE(field.As()->empty()); - EXPECT_EQ(field->DebugString(), "[1, 0]"); - ASSERT_OK_AND_ASSIGN( - auto field_value, - field.As()->Get(ListValue::GetContext(value_factory), 0)); - EXPECT_EQ(field_value.As()->value(), 1); - ASSERT_OK_AND_ASSIGN( - field_value, - field.As()->Get(ListValue::GetContext(value_factory), 1)); - EXPECT_EQ(field_value.As()->value(), 0); + TEST_GET_LIST_FIELD( + memory_manager(), "repeated_int32_wrapper", EmptyListFieldTester, + [](TestAllTypes& message) { + message.add_repeated_int32_wrapper()->set_value(1); + message.add_repeated_int32_wrapper()->set_value(0); + }, + [](ValueFactory& value_factory, const Handle& field) { + EXPECT_EQ(field->size(), 2); + EXPECT_FALSE(field->empty()); + EXPECT_EQ(field->DebugString(), "[1, 0]"); + ASSERT_OK_AND_ASSIGN( + auto field_value, + field->Get(ListValue::GetContext(value_factory), 0)); + EXPECT_EQ(field_value.As()->value(), 1); + ASSERT_OK_AND_ASSIGN( + field_value, field->Get(ListValue::GetContext(value_factory), 1)); + EXPECT_EQ(field_value.As()->value(), 0); + }); } TEST_P(ProtoStructValueTest, Int64WrapperListGetField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN( - auto field, - value_without->GetFieldByName(StructValue::GetFieldContext(value_factory), - "repeated_int64_wrapper")); - EXPECT_TRUE(field->Is()); - EXPECT_EQ(field.As()->size(), 0); - EXPECT_TRUE(field.As()->empty()); - EXPECT_EQ(field->DebugString(), "[]"); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.add_repeated_int64_wrapper()->set_value(1); - message.add_repeated_int64_wrapper()->set_value(0); - }))); - ASSERT_OK_AND_ASSIGN(field, value_with->GetFieldByName( - StructValue::GetFieldContext(value_factory), - "repeated_int64_wrapper")); - EXPECT_TRUE(field->Is()); - EXPECT_EQ(field.As()->size(), 2); - EXPECT_FALSE(field.As()->empty()); - EXPECT_EQ(field->DebugString(), "[1, 0]"); - ASSERT_OK_AND_ASSIGN( - auto field_value, - field.As()->Get(ListValue::GetContext(value_factory), 0)); - EXPECT_EQ(field_value.As()->value(), 1); - ASSERT_OK_AND_ASSIGN( - field_value, - field.As()->Get(ListValue::GetContext(value_factory), 1)); - EXPECT_EQ(field_value.As()->value(), 0); + TEST_GET_LIST_FIELD( + memory_manager(), "repeated_int64_wrapper", EmptyListFieldTester, + [](TestAllTypes& message) { + message.add_repeated_int64_wrapper()->set_value(1); + message.add_repeated_int64_wrapper()->set_value(0); + }, + [](ValueFactory& value_factory, const Handle& field) { + EXPECT_EQ(field->size(), 2); + EXPECT_FALSE(field->empty()); + EXPECT_EQ(field->DebugString(), "[1, 0]"); + ASSERT_OK_AND_ASSIGN( + auto field_value, + field->Get(ListValue::GetContext(value_factory), 0)); + EXPECT_EQ(field_value.As()->value(), 1); + ASSERT_OK_AND_ASSIGN( + field_value, field->Get(ListValue::GetContext(value_factory), 1)); + EXPECT_EQ(field_value.As()->value(), 0); + }); } TEST_P(ProtoStructValueTest, Uint32WrapperListGetField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN( - auto field, - value_without->GetFieldByName(StructValue::GetFieldContext(value_factory), - "repeated_uint32_wrapper")); - EXPECT_TRUE(field->Is()); - EXPECT_EQ(field.As()->size(), 0); - EXPECT_TRUE(field.As()->empty()); - EXPECT_EQ(field->DebugString(), "[]"); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.add_repeated_uint32_wrapper()->set_value(1); - message.add_repeated_uint32_wrapper()->set_value(0); - }))); - ASSERT_OK_AND_ASSIGN(field, value_with->GetFieldByName( - StructValue::GetFieldContext(value_factory), - "repeated_uint32_wrapper")); - EXPECT_TRUE(field->Is()); - EXPECT_EQ(field.As()->size(), 2); - EXPECT_FALSE(field.As()->empty()); - EXPECT_EQ(field->DebugString(), "[1u, 0u]"); - ASSERT_OK_AND_ASSIGN( - auto field_value, - field.As()->Get(ListValue::GetContext(value_factory), 0)); - EXPECT_EQ(field_value.As()->value(), 1); - ASSERT_OK_AND_ASSIGN( - field_value, - field.As()->Get(ListValue::GetContext(value_factory), 1)); - EXPECT_EQ(field_value.As()->value(), 0); + TEST_GET_LIST_FIELD( + memory_manager(), "repeated_uint32_wrapper", EmptyListFieldTester, + [](TestAllTypes& message) { + message.add_repeated_uint32_wrapper()->set_value(1); + message.add_repeated_uint32_wrapper()->set_value(0); + }, + [](ValueFactory& value_factory, const Handle& field) { + EXPECT_EQ(field->size(), 2); + EXPECT_FALSE(field->empty()); + EXPECT_EQ(field->DebugString(), "[1u, 0u]"); + ASSERT_OK_AND_ASSIGN( + auto field_value, + field->Get(ListValue::GetContext(value_factory), 0)); + EXPECT_EQ(field_value.As()->value(), 1); + ASSERT_OK_AND_ASSIGN( + field_value, field->Get(ListValue::GetContext(value_factory), 1)); + EXPECT_EQ(field_value.As()->value(), 0); + }); } TEST_P(ProtoStructValueTest, Uint64WrapperListGetField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN( - auto field, - value_without->GetFieldByName(StructValue::GetFieldContext(value_factory), - "repeated_uint64_wrapper")); - EXPECT_TRUE(field->Is()); - EXPECT_EQ(field.As()->size(), 0); - EXPECT_TRUE(field.As()->empty()); - EXPECT_EQ(field->DebugString(), "[]"); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.add_repeated_uint64_wrapper()->set_value(1); - message.add_repeated_uint64_wrapper()->set_value(0); - }))); - ASSERT_OK_AND_ASSIGN(field, value_with->GetFieldByName( - StructValue::GetFieldContext(value_factory), - "repeated_uint64_wrapper")); - EXPECT_TRUE(field->Is()); - EXPECT_EQ(field.As()->size(), 2); - EXPECT_FALSE(field.As()->empty()); - EXPECT_EQ(field->DebugString(), "[1u, 0u]"); - ASSERT_OK_AND_ASSIGN( - auto field_value, - field.As()->Get(ListValue::GetContext(value_factory), 0)); - EXPECT_EQ(field_value.As()->value(), 1); - ASSERT_OK_AND_ASSIGN( - field_value, - field.As()->Get(ListValue::GetContext(value_factory), 1)); - EXPECT_EQ(field_value.As()->value(), 0); + TEST_GET_LIST_FIELD( + memory_manager(), "repeated_uint64_wrapper", EmptyListFieldTester, + [](TestAllTypes& message) { + message.add_repeated_uint64_wrapper()->set_value(1); + message.add_repeated_uint64_wrapper()->set_value(0); + }, + [](ValueFactory& value_factory, const Handle& field) { + EXPECT_EQ(field->size(), 2); + EXPECT_FALSE(field->empty()); + EXPECT_EQ(field->DebugString(), "[1u, 0u]"); + ASSERT_OK_AND_ASSIGN( + auto field_value, + field->Get(ListValue::GetContext(value_factory), 0)); + EXPECT_EQ(field_value.As()->value(), 1); + ASSERT_OK_AND_ASSIGN( + field_value, field->Get(ListValue::GetContext(value_factory), 1)); + EXPECT_EQ(field_value.As()->value(), 0); + }); } TEST_P(ProtoStructValueTest, FloatWrapperListGetField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN( - auto field, - value_without->GetFieldByName(StructValue::GetFieldContext(value_factory), - "repeated_float_wrapper")); - EXPECT_TRUE(field->Is()); - EXPECT_EQ(field.As()->size(), 0); - EXPECT_TRUE(field.As()->empty()); - EXPECT_EQ(field->DebugString(), "[]"); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage([](TestAllTypes& message) { - message.add_repeated_float_wrapper()->set_value(1.0); - message.add_repeated_float_wrapper()->set_value(0.0); - }))); - ASSERT_OK_AND_ASSIGN(field, value_with->GetFieldByName( - StructValue::GetFieldContext(value_factory), - "repeated_float_wrapper")); - EXPECT_TRUE(field->Is()); - EXPECT_EQ(field.As()->size(), 2); - EXPECT_FALSE(field.As()->empty()); - EXPECT_EQ(field->DebugString(), "[1.0, 0.0]"); - ASSERT_OK_AND_ASSIGN( - auto field_value, - field.As()->Get(ListValue::GetContext(value_factory), 0)); - EXPECT_EQ(field_value.As()->value(), 1.0); - ASSERT_OK_AND_ASSIGN( - field_value, - field.As()->Get(ListValue::GetContext(value_factory), 1)); - EXPECT_EQ(field_value.As()->value(), 0.0); + TEST_GET_LIST_FIELD( + memory_manager(), "repeated_float_wrapper", EmptyListFieldTester, + [](TestAllTypes& message) { + message.add_repeated_float_wrapper()->set_value(1.0); + message.add_repeated_float_wrapper()->set_value(0.0); + }, + [](ValueFactory& value_factory, const Handle& field) { + EXPECT_EQ(field->size(), 2); + EXPECT_FALSE(field->empty()); + EXPECT_EQ(field->DebugString(), "[1.0, 0.0]"); + ASSERT_OK_AND_ASSIGN( + auto field_value, + field->Get(ListValue::GetContext(value_factory), 0)); + EXPECT_EQ(field_value.As()->value(), 1.0); + ASSERT_OK_AND_ASSIGN( + field_value, field->Get(ListValue::GetContext(value_factory), 1)); + EXPECT_EQ(field_value.As()->value(), 0.0); + }); } TEST_P(ProtoStructValueTest, DoubleWrapperListGetField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN( - auto field, - value_without->GetFieldByName(StructValue::GetFieldContext(value_factory), - "repeated_double_wrapper")); - EXPECT_TRUE(field->Is()); - EXPECT_EQ(field.As()->size(), 0); - EXPECT_TRUE(field.As()->empty()); - EXPECT_EQ(field->DebugString(), "[]"); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create( - value_factory, CreateTestMessage([](TestAllTypes& message) { - message.add_repeated_double_wrapper()->set_value(1.0); - message.add_repeated_double_wrapper()->set_value(0.0); - }))); - ASSERT_OK_AND_ASSIGN(field, value_with->GetFieldByName( - StructValue::GetFieldContext(value_factory), - "repeated_double_wrapper")); - EXPECT_TRUE(field->Is()); - EXPECT_EQ(field.As()->size(), 2); - EXPECT_FALSE(field.As()->empty()); - EXPECT_EQ(field->DebugString(), "[1.0, 0.0]"); - ASSERT_OK_AND_ASSIGN( - auto field_value, - field.As()->Get(ListValue::GetContext(value_factory), 0)); - EXPECT_EQ(field_value.As()->value(), 1.0); - ASSERT_OK_AND_ASSIGN( - field_value, - field.As()->Get(ListValue::GetContext(value_factory), 1)); - EXPECT_EQ(field_value.As()->value(), 0.0); + TEST_GET_LIST_FIELD( + memory_manager(), "repeated_double_wrapper", EmptyListFieldTester, + [](TestAllTypes& message) { + message.add_repeated_double_wrapper()->set_value(1.0); + message.add_repeated_double_wrapper()->set_value(0.0); + }, + [](ValueFactory& value_factory, const Handle& field) { + EXPECT_EQ(field->size(), 2); + EXPECT_FALSE(field->empty()); + EXPECT_EQ(field->DebugString(), "[1.0, 0.0]"); + ASSERT_OK_AND_ASSIGN( + auto field_value, + field->Get(ListValue::GetContext(value_factory), 0)); + EXPECT_EQ(field_value.As()->value(), 1.0); + ASSERT_OK_AND_ASSIGN( + field_value, field->Get(ListValue::GetContext(value_factory), 1)); + EXPECT_EQ(field_value.As()->value(), 0.0); + }); } TEST_P(ProtoStructValueTest, BytesWrapperListGetField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN( - auto field, - value_without->GetFieldByName(StructValue::GetFieldContext(value_factory), - "repeated_bytes_wrapper")); - EXPECT_TRUE(field->Is()); - EXPECT_EQ(field.As()->size(), 0); - EXPECT_TRUE(field.As()->empty()); - EXPECT_EQ(field->DebugString(), "[]"); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create( - value_factory, CreateTestMessage([](TestAllTypes& message) { - message.add_repeated_bytes_wrapper()->set_value("foo"); - message.add_repeated_bytes_wrapper()->set_value("bar"); - }))); - ASSERT_OK_AND_ASSIGN(field, value_with->GetFieldByName( - StructValue::GetFieldContext(value_factory), - "repeated_bytes_wrapper")); - EXPECT_TRUE(field->Is()); - EXPECT_EQ(field.As()->size(), 2); - EXPECT_FALSE(field.As()->empty()); - EXPECT_EQ(field->DebugString(), "[b\"foo\", b\"bar\"]"); - ASSERT_OK_AND_ASSIGN( - auto field_value, - field.As()->Get(ListValue::GetContext(value_factory), 0)); - EXPECT_EQ(field_value.As()->ToString(), "foo"); - ASSERT_OK_AND_ASSIGN( - field_value, - field.As()->Get(ListValue::GetContext(value_factory), 1)); - EXPECT_EQ(field_value.As()->ToString(), "bar"); + TEST_GET_LIST_FIELD( + memory_manager(), "repeated_bytes_wrapper", EmptyListFieldTester, + [](TestAllTypes& message) { + message.add_repeated_bytes_wrapper()->set_value("foo"); + message.add_repeated_bytes_wrapper()->set_value("bar"); + }, + [](ValueFactory& value_factory, const Handle& field) { + EXPECT_EQ(field->size(), 2); + EXPECT_FALSE(field->empty()); + EXPECT_EQ(field->DebugString(), "[b\"foo\", b\"bar\"]"); + ASSERT_OK_AND_ASSIGN( + auto field_value, + field->Get(ListValue::GetContext(value_factory), 0)); + EXPECT_EQ(field_value.As()->ToString(), "foo"); + ASSERT_OK_AND_ASSIGN( + field_value, field->Get(ListValue::GetContext(value_factory), 1)); + EXPECT_EQ(field_value.As()->ToString(), "bar"); + }); } TEST_P(ProtoStructValueTest, StringWrapperListGetField) { - TypeFactory type_factory(memory_manager()); - ProtoTypeProvider type_provider; - TypeManager type_manager(type_factory, type_provider); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto value_without, - ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN( - auto field, - value_without->GetFieldByName(StructValue::GetFieldContext(value_factory), - "repeated_string_wrapper")); - EXPECT_TRUE(field->Is()); - EXPECT_EQ(field.As()->size(), 0); - EXPECT_TRUE(field.As()->empty()); - EXPECT_EQ(field->DebugString(), "[]"); - ASSERT_OK_AND_ASSIGN( - auto value_with, - ProtoValue::Create( - value_factory, CreateTestMessage([](TestAllTypes& message) { - message.add_repeated_string_wrapper()->set_value("foo"); - message.add_repeated_string_wrapper()->set_value("bar"); - }))); - ASSERT_OK_AND_ASSIGN(field, value_with->GetFieldByName( - StructValue::GetFieldContext(value_factory), - "repeated_string_wrapper")); - EXPECT_TRUE(field->Is()); - EXPECT_EQ(field.As()->size(), 2); - EXPECT_FALSE(field.As()->empty()); - EXPECT_EQ(field->DebugString(), "[\"foo\", \"bar\"]"); - ASSERT_OK_AND_ASSIGN( - auto field_value, - field.As()->Get(ListValue::GetContext(value_factory), 0)); - EXPECT_EQ(field_value.As()->ToString(), "foo"); - ASSERT_OK_AND_ASSIGN( - field_value, - field.As()->Get(ListValue::GetContext(value_factory), 1)); - EXPECT_EQ(field_value.As()->ToString(), "bar"); + TEST_GET_LIST_FIELD( + memory_manager(), "repeated_string_wrapper", EmptyListFieldTester, + [](TestAllTypes& message) { + message.add_repeated_string_wrapper()->set_value("foo"); + message.add_repeated_string_wrapper()->set_value("bar"); + }, + [](ValueFactory& value_factory, const Handle& field) { + EXPECT_EQ(field->size(), 2); + EXPECT_FALSE(field->empty()); + EXPECT_EQ(field->DebugString(), "[\"foo\", \"bar\"]"); + ASSERT_OK_AND_ASSIGN( + auto field_value, + field->Get(ListValue::GetContext(value_factory), 0)); + EXPECT_EQ(field_value.As()->ToString(), "foo"); + ASSERT_OK_AND_ASSIGN( + field_value, field->Get(ListValue::GetContext(value_factory), 1)); + EXPECT_EQ(field_value.As()->ToString(), "bar"); + }); } template From 203c057c512464ddb877c8d12cbe168a4ebe91de Mon Sep 17 00:00:00 2001 From: jcking Date: Thu, 11 May 2023 21:38:15 +0000 Subject: [PATCH 262/303] Run field presence and access tests for both name and number PiperOrigin-RevId: 531310682 --- extensions/protobuf/BUILD | 1 + extensions/protobuf/struct_value_test.cc | 283 +++++++++++++++++++---- 2 files changed, 234 insertions(+), 50 deletions(-) diff --git a/extensions/protobuf/BUILD b/extensions/protobuf/BUILD index c7b24fa94..fcd9117c7 100644 --- a/extensions/protobuf/BUILD +++ b/extensions/protobuf/BUILD @@ -195,6 +195,7 @@ cc_test( "//extensions/protobuf/internal:testing", "//internal:testing", "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/log:die_if_null", "@com_google_absl//absl/status", "@com_google_absl//absl/time", "@com_google_absl//absl/types:optional", diff --git a/extensions/protobuf/struct_value_test.cc b/extensions/protobuf/struct_value_test.cc index a624c4a77..3588fc59e 100644 --- a/extensions/protobuf/struct_value_test.cc +++ b/extensions/protobuf/struct_value_test.cc @@ -22,6 +22,7 @@ #include "google/protobuf/timestamp.pb.h" #include "google/protobuf/wrappers.pb.h" #include "absl/functional/function_ref.h" +#include "absl/log/die_if_null.h" #include "absl/status/status.h" #include "absl/time/time.h" #include "absl/types/optional.h" @@ -92,28 +93,65 @@ T Must(absl::StatusOr status_or) { return Must(std::move(status_or).value()); } -template -void TestHasField(MemoryManager& memory_manager, absl::string_view name, - TestMessageMaker&& test_message_maker, bool found = true) { +int TestMessageFieldNameToNumber(absl::string_view name) { + const auto* descriptor = TestAllTypes::descriptor(); + return ABSL_DIE_IF_NULL(descriptor->FindFieldByName(name))->number(); +} + +void TestHasFieldImpl( + MemoryManager& memory_manager, + absl::FunctionRef(const Handle&, + const StructValue::HasFieldContext&)> + has_field, + absl::FunctionRef test_message_maker, bool found) { TypeFactory type_factory(memory_manager); ProtoTypeProvider type_provider; TypeManager type_manager(type_factory, type_provider); ValueFactory value_factory(type_manager); ASSERT_OK_AND_ASSIGN(auto value_without, ProtoValue::Create(value_factory, CreateTestMessage())); - EXPECT_THAT(value_without->HasFieldByName( - StructValue::HasFieldContext(type_manager), name), - IsOkAndHolds(Eq(false))); + EXPECT_THAT( + has_field(value_without, StructValue::HasFieldContext(type_manager)), + IsOkAndHolds(Eq(false))); ASSERT_OK_AND_ASSIGN( auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage(std::forward( - test_message_maker)))); - EXPECT_THAT(value_with->HasFieldByName( - StructValue::HasFieldContext(type_manager), name), + ProtoValue::Create(value_factory, CreateTestMessage(test_message_maker))); + EXPECT_THAT(has_field(value_with, StructValue::HasFieldContext(type_manager)), IsOkAndHolds(Eq(found))); } +void TestHasFieldByName( + MemoryManager& memory_manager, absl::string_view name, + absl::FunctionRef test_message_maker, bool found) { + TestHasFieldImpl( + memory_manager, + [&](const Handle& value, + const StructValue::HasFieldContext& context) { + return value->HasFieldByName(context, name); + }, + test_message_maker, found); +} + +void TestHasFieldByNumber( + MemoryManager& memory_manager, int64_t number, + absl::FunctionRef test_message_maker, bool found) { + TestHasFieldImpl( + memory_manager, + [&](const Handle& value, + const StructValue::HasFieldContext& context) { + return value->HasFieldByNumber(context, number); + }, + test_message_maker, found); +} + +void TestHasField(MemoryManager& memory_manager, absl::string_view name, + absl::FunctionRef test_message_maker, + bool found = true) { + TestHasFieldByName(memory_manager, name, test_message_maker, found); + TestHasFieldByNumber(memory_manager, TestMessageFieldNameToNumber(name), + test_message_maker, found); +} + #define TEST_HAS_FIELD(...) ASSERT_NO_FATAL_FAILURE(TestHasField(__VA_ARGS__)) TEST_P(ProtoStructValueTest, NullValueHasField) { @@ -428,8 +466,11 @@ TEST_P(ProtoStructValueTest, ValueListHasField) { [](TestAllTypes& message) { message.add_repeated_value(); }); } -void TestGetField( - MemoryManager& memory_manager, absl::string_view name, +void TestGetFieldImpl( + MemoryManager& memory_manager, + absl::FunctionRef>( + const Handle&, const StructValue::GetFieldContext&)> + get_field, absl::FunctionRef&)> unset_field_tester, absl::FunctionRef test_message_maker, absl::FunctionRef&)> @@ -440,19 +481,62 @@ void TestGetField( ValueFactory value_factory(type_manager); ASSERT_OK_AND_ASSIGN(auto value_without, ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN(auto field, - value_without->GetFieldByName( - StructValue::GetFieldContext(value_factory), name)); + ASSERT_OK_AND_ASSIGN( + auto field, + get_field(value_without, StructValue::GetFieldContext(value_factory))); ASSERT_NO_FATAL_FAILURE(unset_field_tester(field)); ASSERT_OK_AND_ASSIGN( auto value_with, ProtoValue::Create(value_factory, CreateTestMessage(test_message_maker))); - ASSERT_OK_AND_ASSIGN(field, - value_with->GetFieldByName( - StructValue::GetFieldContext(value_factory), name)); + ASSERT_OK_AND_ASSIGN( + field, + get_field(value_with, StructValue::GetFieldContext(value_factory))); ASSERT_NO_FATAL_FAILURE(set_field_tester(value_factory, field)); } +void TestGetFieldByName( + MemoryManager& memory_manager, absl::string_view name, + absl::FunctionRef&)> unset_field_tester, + absl::FunctionRef test_message_maker, + absl::FunctionRef&)> + set_field_tester) { + TestGetFieldImpl( + memory_manager, + [&](const Handle& value, + const StructValue::GetFieldContext& context) { + return value->GetFieldByName(context, name); + }, + unset_field_tester, test_message_maker, set_field_tester); +} + +void TestGetFieldByNumber( + MemoryManager& memory_manager, int64_t number, + absl::FunctionRef&)> unset_field_tester, + absl::FunctionRef test_message_maker, + absl::FunctionRef&)> + set_field_tester) { + TestGetFieldImpl( + memory_manager, + [&](const Handle& value, + const StructValue::GetFieldContext& context) { + return value->GetFieldByNumber(context, number); + }, + unset_field_tester, test_message_maker, set_field_tester); +} + +void TestGetField( + MemoryManager& memory_manager, absl::string_view name, + absl::FunctionRef&)> unset_field_tester, + absl::FunctionRef test_message_maker, + absl::FunctionRef&)> + set_field_tester) { + TestGetFieldByName(memory_manager, name, unset_field_tester, + test_message_maker, set_field_tester); + TestGetFieldByNumber(memory_manager, TestMessageFieldNameToNumber(name), + unset_field_tester, test_message_maker, + set_field_tester); +} + void TestGetField( MemoryManager& memory_manager, absl::string_view name, absl::FunctionRef&)> unset_field_tester, @@ -661,12 +745,15 @@ TEST_P(ProtoStructValueTest, MessageGetField) { }); } -template -void TestGetWrapperField(MemoryManager& memory_manager, absl::string_view name, - UnsetFieldTester&& unset_field_tester, - TestMessageMaker&& test_message_maker, - SetFieldTester&& set_field_tester) { +void TestGetWrapperFieldImpl( + MemoryManager& memory_manager, + absl::FunctionRef>( + const Handle&, const StructValue::GetFieldContext&)> + get_field, + absl::FunctionRef&)> unset_field_tester, + absl::FunctionRef test_message_maker, + absl::FunctionRef&)> + set_field_tester) { TypeFactory type_factory(memory_manager); ProtoTypeProvider type_provider; TypeManager type_manager(type_factory, type_provider); @@ -675,26 +762,76 @@ void TestGetWrapperField(MemoryManager& memory_manager, absl::string_view name, ProtoValue::Create(value_factory, CreateTestMessage())); ASSERT_OK_AND_ASSIGN( auto field, - value_without->GetFieldByName(StructValue::GetFieldContext(value_factory) - .set_unbox_null_wrapper_types(true), - name)); + get_field(value_without, StructValue::GetFieldContext(value_factory) + .set_unbox_null_wrapper_types(true))); EXPECT_TRUE(field->Is()); - ASSERT_OK_AND_ASSIGN(field, value_without->GetFieldByName( - StructValue::GetFieldContext(value_factory) - .set_unbox_null_wrapper_types(false), - name)); - ASSERT_NO_FATAL_FAILURE( - (std::forward(unset_field_tester)(field))); + ASSERT_OK_AND_ASSIGN( + field, + get_field(value_without, StructValue::GetFieldContext(value_factory) + .set_unbox_null_wrapper_types(false))); + ASSERT_NO_FATAL_FAILURE(unset_field_tester(field)); ASSERT_OK_AND_ASSIGN( auto value_with, - ProtoValue::Create(value_factory, - CreateTestMessage(std::forward( - test_message_maker)))); - ASSERT_OK_AND_ASSIGN(field, - value_with->GetFieldByName( - StructValue::GetFieldContext(value_factory), name)); - ASSERT_NO_FATAL_FAILURE( - (std::forward(set_field_tester)(field))); + ProtoValue::Create(value_factory, CreateTestMessage(test_message_maker))); + ASSERT_OK_AND_ASSIGN( + field, + get_field(value_with, StructValue::GetFieldContext(value_factory))); + ASSERT_NO_FATAL_FAILURE(set_field_tester(value_factory, field)); +} + +void TestGetWrapperFieldByName( + MemoryManager& memory_manager, absl::string_view name, + absl::FunctionRef&)> unset_field_tester, + absl::FunctionRef test_message_maker, + absl::FunctionRef&)> + set_field_tester) { + TestGetWrapperFieldImpl( + memory_manager, + [&](const Handle& value, + const StructValue::GetFieldContext& context) { + return value->GetFieldByName(context, name); + }, + unset_field_tester, test_message_maker, set_field_tester); +} + +void TestGetWrapperFieldByNumber( + MemoryManager& memory_manager, int64_t number, + absl::FunctionRef&)> unset_field_tester, + absl::FunctionRef test_message_maker, + absl::FunctionRef&)> + set_field_tester) { + TestGetWrapperFieldImpl( + memory_manager, + [&](const Handle& value, + const StructValue::GetFieldContext& context) { + return value->GetFieldByNumber(context, number); + }, + unset_field_tester, test_message_maker, set_field_tester); +} + +void TestGetWrapperField( + MemoryManager& memory_manager, absl::string_view name, + absl::FunctionRef&)> unset_field_tester, + absl::FunctionRef test_message_maker, + absl::FunctionRef&)> + set_field_tester) { + TestGetWrapperFieldByName(memory_manager, name, unset_field_tester, + test_message_maker, set_field_tester); + TestGetWrapperFieldByNumber( + memory_manager, TestMessageFieldNameToNumber(name), unset_field_tester, + test_message_maker, set_field_tester); +} + +void TestGetWrapperField( + MemoryManager& memory_manager, absl::string_view name, + absl::FunctionRef&)> unset_field_tester, + absl::FunctionRef test_message_maker, + absl::FunctionRef&)> set_field_tester) { + TestGetWrapperField( + memory_manager, name, unset_field_tester, test_message_maker, + [&](ValueFactory& value_factory, const Handle& field) { + set_field_tester(field); + }); } #define TEST_GET_WRAPPER_FIELD(...) \ @@ -880,8 +1017,11 @@ TEST_P(ProtoStructValueTest, ValueGetField) { }); } -void TestGetListField( - MemoryManager& memory_manager, absl::string_view name, +void TestGetListFieldImpl( + MemoryManager& memory_manager, + absl::FunctionRef>( + const Handle&, const StructValue::GetFieldContext&)> + get_field, absl::FunctionRef&)> unset_field_tester, absl::FunctionRef test_message_maker, absl::FunctionRef&)> @@ -892,22 +1032,65 @@ void TestGetListField( ValueFactory value_factory(type_manager); ASSERT_OK_AND_ASSIGN(auto value_without, ProtoValue::Create(value_factory, CreateTestMessage())); - ASSERT_OK_AND_ASSIGN(auto field, - value_without->GetFieldByName( - StructValue::GetFieldContext(value_factory), name)); + ASSERT_OK_AND_ASSIGN( + auto field, + get_field(value_without, StructValue::GetFieldContext(value_factory))); ASSERT_TRUE(field->Is()); ASSERT_NO_FATAL_FAILURE(unset_field_tester(field.As())); ASSERT_OK_AND_ASSIGN( auto value_with, ProtoValue::Create(value_factory, CreateTestMessage(test_message_maker))); - ASSERT_OK_AND_ASSIGN(field, - value_with->GetFieldByName( - StructValue::GetFieldContext(value_factory), name)); + ASSERT_OK_AND_ASSIGN( + field, + get_field(value_with, StructValue::GetFieldContext(value_factory))); ASSERT_TRUE(field->Is()); ASSERT_NO_FATAL_FAILURE( set_field_tester(value_factory, field.As())); } +void TestGetListFieldByName( + MemoryManager& memory_manager, absl::string_view name, + absl::FunctionRef&)> unset_field_tester, + absl::FunctionRef test_message_maker, + absl::FunctionRef&)> + set_field_tester) { + TestGetListFieldImpl( + memory_manager, + [&](const Handle& value, + const StructValue::GetFieldContext& context) { + return value->GetFieldByName(context, name); + }, + unset_field_tester, test_message_maker, set_field_tester); +} + +void TestGetListFieldByNumber( + MemoryManager& memory_manager, int64_t number, + absl::FunctionRef&)> unset_field_tester, + absl::FunctionRef test_message_maker, + absl::FunctionRef&)> + set_field_tester) { + TestGetListFieldImpl( + memory_manager, + [&](const Handle& value, + const StructValue::GetFieldContext& context) { + return value->GetFieldByNumber(context, number); + }, + unset_field_tester, test_message_maker, set_field_tester); +} + +void TestGetListField( + MemoryManager& memory_manager, absl::string_view name, + absl::FunctionRef&)> unset_field_tester, + absl::FunctionRef test_message_maker, + absl::FunctionRef&)> + set_field_tester) { + TestGetListFieldByName(memory_manager, name, unset_field_tester, + test_message_maker, set_field_tester); + TestGetListFieldByNumber(memory_manager, TestMessageFieldNameToNumber(name), + unset_field_tester, test_message_maker, + set_field_tester); +} + #define TEST_GET_LIST_FIELD(...) \ ASSERT_NO_FATAL_FAILURE(TestGetListField(__VA_ARGS__)) From 46392037322444852091f5d852de5ef1e4b57969 Mon Sep 17 00:00:00 2001 From: jdtatum Date: Thu, 11 May 2023 21:43:10 +0000 Subject: [PATCH 263/303] Update flat expr builder program optimizer API to accept a factory instead of an instance. PiperOrigin-RevId: 531312047 --- eval/compiler/BUILD | 1 + eval/compiler/constant_folding.cc | 15 ++++------- eval/compiler/constant_folding.h | 2 +- eval/compiler/constant_folding_test.cc | 28 ++++++++++---------- eval/compiler/flat_expr_builder.cc | 11 ++++---- eval/compiler/flat_expr_builder.h | 4 +-- eval/compiler/flat_expr_builder_extensions.h | 19 ++++++++++--- 7 files changed, 44 insertions(+), 36 deletions(-) diff --git a/eval/compiler/BUILD b/eval/compiler/BUILD index f413a278e..c1e8740da 100644 --- a/eval/compiler/BUILD +++ b/eval/compiler/BUILD @@ -20,6 +20,7 @@ cc_library( "//runtime:runtime_options", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status", "@com_google_absl//absl/types:span", ], diff --git a/eval/compiler/constant_folding.cc b/eval/compiler/constant_folding.cc index ceb2a9237..8ee28bf0c 100644 --- a/eval/compiler/constant_folding.cc +++ b/eval/compiler/constant_folding.cc @@ -393,14 +393,6 @@ class ConstantFoldingExtension : public ProgramOptimizer { ConstantFoldingExtension(int stack_limit, google::protobuf::Arena* arena) : arena_(arena), state_(stack_limit, arena) {} - absl::Status OnInit(google::api::expr::runtime::PlannerContext& context, - const AstImpl& ast) override { - // Clean up const stack incase of failure in the middle of planning previous - // expression. - is_const_.clear(); - return absl::OkStatus(); - } - absl::Status OnPreVisit(google::api::expr::runtime::PlannerContext& context, const Expr& node) override; absl::Status OnPostVisit(google::api::expr::runtime::PlannerContext& context, @@ -528,10 +520,13 @@ void FoldConstants( constant_folder.Transform(ast, out_ast); } -std::unique_ptr +google::api::expr::runtime::ProgramOptimizerFactory CreateConstantFoldingExtension(google::protobuf::Arena* arena, ConstantFoldingOptions options) { - return std::make_unique(options.stack_limit, arena); + return [=](PlannerContext&, const AstImpl&) { + return std::make_unique(options.stack_limit, + arena); + }; } } // namespace cel::ast::internal diff --git a/eval/compiler/constant_folding.h b/eval/compiler/constant_folding.h index 76d5a3f69..74641c73b 100644 --- a/eval/compiler/constant_folding.h +++ b/eval/compiler/constant_folding.h @@ -34,7 +34,7 @@ struct ConstantFoldingOptions { // Create a new constant folding extension. // Eagerly evaluates sub expressions with all constant inputs, and replaces said // sub expression with the result. -std::unique_ptr +google::api::expr::runtime::ProgramOptimizerFactory CreateConstantFoldingExtension( google::protobuf::Arena* arena, ConstantFoldingOptions options = ConstantFoldingOptions()); diff --git a/eval/compiler/constant_folding_test.cc b/eval/compiler/constant_folding_test.cc index 542750346..a5276ff6a 100644 --- a/eval/compiler/constant_folding_test.cc +++ b/eval/compiler/constant_folding_test.cc @@ -48,6 +48,7 @@ using ::google::api::expr::runtime::CreateConstValueStep; using ::google::api::expr::runtime::ExecutionPath; using ::google::api::expr::runtime::PlannerContext; using ::google::api::expr::runtime::ProgramOptimizer; +using ::google::api::expr::runtime::ProgramOptimizerFactory; using ::google::api::expr::runtime::Resolver; using ::google::protobuf::Arena; using testing::SizeIs; @@ -606,12 +607,13 @@ TEST_F(UpdatedConstantFoldingTest, SkipsTernary) { google::protobuf::Arena arena; constexpr int kStackLimit = 1; - std::unique_ptr constant_folder = + ProgramOptimizerFactory constant_folder_factory = CreateConstantFoldingExtension(&arena, {kStackLimit}); // Act // Issue the visitation calls. - ASSERT_OK(constant_folder->OnInit(context, ast_impl)); + ASSERT_OK_AND_ASSIGN(std::unique_ptr constant_folder, + constant_folder_factory(context, ast_impl)); ASSERT_OK(constant_folder->OnPreVisit(context, call)); ASSERT_OK(constant_folder->OnPreVisit(context, condition)); ASSERT_OK(constant_folder->OnPostVisit(context, condition)); @@ -672,12 +674,13 @@ TEST_F(UpdatedConstantFoldingTest, SkipsOr) { google::protobuf::Arena arena; constexpr int kStackLimit = 1; - std::unique_ptr constant_folder = + ProgramOptimizerFactory constant_folder_factory = CreateConstantFoldingExtension(&arena, {kStackLimit}); // Act // Issue the visitation calls. - ASSERT_OK(constant_folder->OnInit(context, ast_impl)); + ASSERT_OK_AND_ASSIGN(std::unique_ptr constant_folder, + constant_folder_factory(context, ast_impl)); ASSERT_OK(constant_folder->OnPreVisit(context, call)); ASSERT_OK(constant_folder->OnPreVisit(context, left_condition)); ASSERT_OK(constant_folder->OnPostVisit(context, left_condition)); @@ -736,12 +739,13 @@ TEST_F(UpdatedConstantFoldingTest, SkipsAnd) { google::protobuf::Arena arena; constexpr int kStackLimit = 1; - std::unique_ptr constant_folder = + ProgramOptimizerFactory constant_folder_factory = CreateConstantFoldingExtension(&arena, {kStackLimit}); // Act // Issue the visitation calls. - ASSERT_OK(constant_folder->OnInit(context, ast_impl)); + ASSERT_OK_AND_ASSIGN(std::unique_ptr constant_folder, + constant_folder_factory(context, ast_impl)); ASSERT_OK(constant_folder->OnPreVisit(context, call)); ASSERT_OK(constant_folder->OnPreVisit(context, left_condition)); ASSERT_OK(constant_folder->OnPostVisit(context, left_condition)); @@ -800,16 +804,12 @@ TEST_F(UpdatedConstantFoldingTest, ErrorsOnUnexpectedOrder) { google::protobuf::Arena arena; constexpr int kStackLimit = 1; - std::unique_ptr constant_folder = + ProgramOptimizerFactory constant_folder_factory = CreateConstantFoldingExtension(&arena, {kStackLimit}); - // Act - // Issue the visitation calls in wrong order. - ASSERT_OK(constant_folder->OnPreVisit(context, call)); - ASSERT_OK(constant_folder->OnPreVisit(context, left_condition)); - ASSERT_OK(constant_folder->OnInit(context, ast_impl)); - - // ASSERT + // Act / Assert + ASSERT_OK_AND_ASSIGN(std::unique_ptr constant_folder, + constant_folder_factory(context, ast_impl)); EXPECT_THAT(constant_folder->OnPostVisit(context, left_condition), StatusIs(absl::StatusCode::kInternal)); } diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index 0170f6e2e..4f87341ad 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -1329,15 +1329,16 @@ FlatExprBuilder::CreateExpressionImpl( auto arena = std::make_unique(); - for (const std::unique_ptr& optimizer : - program_optimizers_) { - CEL_RETURN_IF_ERROR(optimizer->OnInit(extension_context, ast_impl)); + std::vector> optimizers; + for (const ProgramOptimizerFactory& optimizer_factory : program_optimizers_) { + CEL_ASSIGN_OR_RETURN(optimizers.emplace_back(), + optimizer_factory(extension_context, ast_impl)); } FlatExprVisitor visitor( resolver, options_, constant_idents, enable_comprehension_vulnerability_check_, enable_regex_precompilation_, - program_optimizers_, &ast_impl.reference_map(), &execution_path, - &warnings_builder, arena.get(), program_tree, extension_context); + optimizers, &ast_impl.reference_map(), &execution_path, &warnings_builder, + arena.get(), program_tree, extension_context); AstTraverse(effective_expr, &ast_impl.source_info(), &visitor); diff --git a/eval/compiler/flat_expr_builder.h b/eval/compiler/flat_expr_builder.h index 8898a15ca..6bb808fec 100644 --- a/eval/compiler/flat_expr_builder.h +++ b/eval/compiler/flat_expr_builder.h @@ -67,7 +67,7 @@ class FlatExprBuilder : public CelExpressionBuilder { ast_transforms_.push_back(std::move(transform)); } - void AddProgramOptimizer(std::unique_ptr optimizer) { + void AddProgramOptimizer(ProgramOptimizerFactory optimizer) { program_optimizers_.push_back(std::move(optimizer)); } @@ -103,7 +103,7 @@ class FlatExprBuilder : public CelExpressionBuilder { cel::RuntimeOptions options_; std::vector> ast_transforms_; - std::vector> program_optimizers_; + std::vector program_optimizers_; bool enable_regex_precompilation_ = false; bool enable_comprehension_vulnerability_check_ = false; diff --git a/eval/compiler/flat_expr_builder_extensions.h b/eval/compiler/flat_expr_builder_extensions.h index 3c0eac70a..edd134eab 100644 --- a/eval/compiler/flat_expr_builder_extensions.h +++ b/eval/compiler/flat_expr_builder_extensions.h @@ -22,9 +22,11 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_COMPILER_FLAT_EXPR_BUILDER_EXTENSIONS_H_ #define THIRD_PARTY_CEL_CPP_EVAL_COMPILER_FLAT_EXPR_BUILDER_EXTENSIONS_H_ +#include #include #include "absl/container/flat_hash_map.h" +#include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "base/ast.h" #include "base/ast_internal.h" @@ -106,10 +108,6 @@ class ProgramOptimizer { public: virtual ~ProgramOptimizer() = default; - // Called once before program planning begins for the given AST. - virtual absl::Status OnInit(PlannerContext& context, - const cel::ast::internal::AstImpl& ast) = 0; - // Called before planning the given expr node. virtual absl::Status OnPreVisit(PlannerContext& context, const cel::ast::internal::Expr& node) = 0; @@ -119,6 +117,19 @@ class ProgramOptimizer { const cel::ast::internal::Expr& node) = 0; }; +// Type definition for ProgramOptimizer factories. +// +// The expression builder must remain thread compatible, but ProgramOptimizers +// are often stateful for a given expression. To avoid requiring the optimizer +// implementation to handle concurrent planning, the builder creates a new +// instance per expression planned. +// +// The factory must be thread safe, but the returned instance may assume +// it is called from a synchronous context. +using ProgramOptimizerFactory = + absl::AnyInvocable>( + PlannerContext&, const cel::ast::internal::AstImpl&) const>; + } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_COMPILER_FLAT_EXPR_BUILDER_EXTENSIONS_H_ From e11d97198ff9f248f7e5c1d3a1145cb175adb4ec Mon Sep 17 00:00:00 2001 From: tswadell Date: Thu, 11 May 2023 21:50:32 +0000 Subject: [PATCH 264/303] Internal tool change PiperOrigin-RevId: 531313837 --- eval/testutil/test_message.proto | 1 + 1 file changed, 1 insertion(+) diff --git a/eval/testutil/test_message.proto b/eval/testutil/test_message.proto index 3fae9f915..0c85be2c8 100644 --- a/eval/testutil/test_message.proto +++ b/eval/testutil/test_message.proto @@ -67,6 +67,7 @@ message TestMessage { map int64_enum_map = 208; map string_timestamp_map = 209; map string_message_map = 210; + map int64_timestamp_map = 211; // Well-known types. google.protobuf.Any any_value = 300; From c0518ac55ee4201d8aa151c3d9ade558fb4e579a Mon Sep 17 00:00:00 2001 From: jdtatum Date: Thu, 11 May 2023 22:11:49 +0000 Subject: [PATCH 265/303] Adjust stack limits for evaluating sub expressions in the constant folding optimization. PiperOrigin-RevId: 531319295 --- eval/compiler/constant_folding.cc | 28 +++++++++++++++++--------- eval/compiler/constant_folding.h | 14 +------------ eval/compiler/constant_folding_test.cc | 12 ++++------- eval/eval/evaluator_stack.h | 13 +++++++----- 4 files changed, 31 insertions(+), 36 deletions(-) diff --git a/eval/compiler/constant_folding.cc b/eval/compiler/constant_folding.cc index 8ee28bf0c..6ddf660b7 100644 --- a/eval/compiler/constant_folding.cc +++ b/eval/compiler/constant_folding.cc @@ -43,7 +43,9 @@ using ::cel::interop_internal::CreateErrorValueFromView; using ::cel::interop_internal::CreateLegacyListValue; using ::cel::interop_internal::CreateNoMatchingOverloadError; using ::cel::interop_internal::ModernValueToLegacyValueOrDie; +using ::google::api::expr::runtime::Activation; using ::google::api::expr::runtime::CelEvaluationListener; +using ::google::api::expr::runtime::CelExpressionFlatEvaluationState; using ::google::api::expr::runtime::CelValue; using ::google::api::expr::runtime::ContainerBackedListImpl; using ::google::api::expr::runtime::ExecutionFrame; @@ -390,8 +392,8 @@ bool ConstantFoldingTransform::Transform(const Expr& expr, Expr& out_) { class ConstantFoldingExtension : public ProgramOptimizer { public: - ConstantFoldingExtension(int stack_limit, google::protobuf::Arena* arena) - : arena_(arena), state_(stack_limit, arena) {} + explicit ConstantFoldingExtension(google::protobuf::Arena* arena) + : arena_(arena), state_(kDefaultStackLimit, arena) {} absl::Status OnPreVisit(google::api::expr::runtime::PlannerContext& context, const Expr& node) override; @@ -403,11 +405,14 @@ class ConstantFoldingExtension : public ProgramOptimizer { kConditional, kNonConst, }; + // Most constant folding evaluations are simple + // binary operators. + static constexpr size_t kDefaultStackLimit = 4; google::protobuf::Arena* arena_; - google::api::expr::runtime::Activation empty_; - google::api::expr::runtime::CelEvaluationListener null_listener_; - google::api::expr::runtime::CelExpressionFlatEvaluationState state_; + Activation empty_; + CelEvaluationListener null_listener_; + CelExpressionFlatEvaluationState state_; std::vector is_const_; }; @@ -440,7 +445,7 @@ absl::Status ConstantFoldingExtension::OnPreVisit(PlannerContext& context, IsConst operator()(absl::monostate) { return IsConst::kNonConst; } IsConst operator()(const Call& call) { - // Shortcircuiting operators not yet supported. + // Short Circuiting operators not yet supported. if (call.function() == kAnd || call.function() == kOr || call.function() == kTernary) { return IsConst::kNonConst; @@ -496,6 +501,11 @@ absl::Status ConstantFoldingExtension::OnPostVisit(PlannerContext& context, ExecutionFrame frame(subplan, empty_, &context.type_registry(), context.options(), &state_); state_.Reset(); + // Update stack size to accommodate sub expression. + // This only results in a vector resize if the new maxsize is greater than + // the current capacity. + state_.value_stack().SetMaxSize(subplan.size()); + CEL_ASSIGN_OR_RETURN(value, frame.Evaluate(null_listener_)); if (value->Is()) { return absl::OkStatus(); @@ -521,11 +531,9 @@ void FoldConstants( } google::api::expr::runtime::ProgramOptimizerFactory -CreateConstantFoldingExtension(google::protobuf::Arena* arena, - ConstantFoldingOptions options) { +CreateConstantFoldingExtension(google::protobuf::Arena* arena) { return [=](PlannerContext&, const AstImpl&) { - return std::make_unique(options.stack_limit, - arena); + return std::make_unique(arena); }; } diff --git a/eval/compiler/constant_folding.h b/eval/compiler/constant_folding.h index 74641c73b..77326f8aa 100644 --- a/eval/compiler/constant_folding.h +++ b/eval/compiler/constant_folding.h @@ -21,23 +21,11 @@ void FoldConstants( absl::flat_hash_map>& constant_idents, Expr& out_ast); -struct ConstantFoldingOptions { - // Stack limit for evaluating constant sub expressions. - // Should accommodate the maximum expected number of dependencies for a small - // subexpression (e.g. number of elements in a list). - // - // 64 is sufficient to support map literals with 32 key/value pairs per the - // minimum required support in the CEL spec. - int stack_limit = 64; -}; - // Create a new constant folding extension. // Eagerly evaluates sub expressions with all constant inputs, and replaces said // sub expression with the result. google::api::expr::runtime::ProgramOptimizerFactory -CreateConstantFoldingExtension( - google::protobuf::Arena* arena, - ConstantFoldingOptions options = ConstantFoldingOptions()); +CreateConstantFoldingExtension(google::protobuf::Arena* arena); } // namespace cel::ast::internal diff --git a/eval/compiler/constant_folding_test.cc b/eval/compiler/constant_folding_test.cc index a5276ff6a..0d91339b9 100644 --- a/eval/compiler/constant_folding_test.cc +++ b/eval/compiler/constant_folding_test.cc @@ -606,9 +606,8 @@ TEST_F(UpdatedConstantFoldingTest, SkipsTernary) { path, tree); google::protobuf::Arena arena; - constexpr int kStackLimit = 1; ProgramOptimizerFactory constant_folder_factory = - CreateConstantFoldingExtension(&arena, {kStackLimit}); + CreateConstantFoldingExtension(&arena); // Act // Issue the visitation calls. @@ -673,9 +672,8 @@ TEST_F(UpdatedConstantFoldingTest, SkipsOr) { path, tree); google::protobuf::Arena arena; - constexpr int kStackLimit = 1; ProgramOptimizerFactory constant_folder_factory = - CreateConstantFoldingExtension(&arena, {kStackLimit}); + CreateConstantFoldingExtension(&arena); // Act // Issue the visitation calls. @@ -738,9 +736,8 @@ TEST_F(UpdatedConstantFoldingTest, SkipsAnd) { path, tree); google::protobuf::Arena arena; - constexpr int kStackLimit = 1; ProgramOptimizerFactory constant_folder_factory = - CreateConstantFoldingExtension(&arena, {kStackLimit}); + CreateConstantFoldingExtension(&arena); // Act // Issue the visitation calls. @@ -803,9 +800,8 @@ TEST_F(UpdatedConstantFoldingTest, ErrorsOnUnexpectedOrder) { path, tree); google::protobuf::Arena arena; - constexpr int kStackLimit = 1; ProgramOptimizerFactory constant_folder_factory = - CreateConstantFoldingExtension(&arena, {kStackLimit}); + CreateConstantFoldingExtension(&arena); // Act / Assert ASSERT_OK_AND_ASSIGN(std::unique_ptr constant_folder, diff --git a/eval/eval/evaluator_stack.h b/eval/eval/evaluator_stack.h index 8348b3b73..63ace19fb 100644 --- a/eval/eval/evaluator_stack.h +++ b/eval/eval/evaluator_stack.h @@ -124,19 +124,22 @@ class EvaluatorStack { attribute_stack_[current_size_ - 1] = std::move(attribute); } + // Update the max size of the stack and update capacity if needed. + void SetMaxSize(size_t size) { + max_size_ = size; + Reserve(size); + } + + private: // Preallocate stack. void Reserve(size_t size) { - if (size > max_size()) { - size = max_size(); - } stack_.reserve(size); attribute_stack_.reserve(size); } - private: std::vector> stack_; std::vector attribute_stack_; - const size_t max_size_; + size_t max_size_; size_t current_size_; }; From cec04633e5a50f76f5681d64f85493ca7d67324f Mon Sep 17 00:00:00 2001 From: tswadell Date: Fri, 12 May 2023 18:22:34 +0000 Subject: [PATCH 266/303] Internal tool update PiperOrigin-RevId: 531558145 --- eval/testutil/test_message.proto | 1 - 1 file changed, 1 deletion(-) diff --git a/eval/testutil/test_message.proto b/eval/testutil/test_message.proto index 0c85be2c8..8369dba35 100644 --- a/eval/testutil/test_message.proto +++ b/eval/testutil/test_message.proto @@ -74,7 +74,6 @@ message TestMessage { google.protobuf.Duration duration_value = 301; google.protobuf.Timestamp timestamp_value = 302; google.protobuf.Struct struct_value = 303; - // TODO(issues/5): Test null_value with variable bindings. google.protobuf.Value value_value = 304; google.protobuf.Int64Value int64_wrapper_value = 305; google.protobuf.Int32Value int32_wrapper_value = 306; From 52604a1d80474908662bc78eee92505d684fc8f7 Mon Sep 17 00:00:00 2001 From: jcking Date: Fri, 12 May 2023 19:32:33 +0000 Subject: [PATCH 267/303] Cache builtin types in `TypeManager` PiperOrigin-RevId: 531576309 --- base/type.cc | 15 +++++++++ base/type.h | 7 +++++ base/type_manager.cc | 67 +++++++++++++++++++++++++++++++++-------- base/type_manager.h | 7 +++-- base/types/dyn_type.cc | 8 +++++ base/types/dyn_type.h | 6 ++++ base/types/list_type.cc | 11 +++++++ base/types/list_type.h | 5 +++ base/types/map_type.cc | 10 ++++++ base/types/map_type.h | 5 +++ 10 files changed, 126 insertions(+), 15 deletions(-) diff --git a/base/type.cc b/base/type.cc index 08922ebde..dcdda19b1 100644 --- a/base/type.cc +++ b/base/type.cc @@ -19,6 +19,7 @@ #include #include "absl/base/optimization.h" +#include "absl/strings/string_view.h" #include "base/internal/data.h" #include "base/types/any_type.h" #include "base/types/bool_type.h" @@ -92,6 +93,20 @@ absl::string_view Type::name() const { } } +absl::Span Type::aliases() const { + switch (kind()) { + case Kind::kDyn: + return static_cast(this)->aliases(); + case Kind::kList: + return static_cast(this)->aliases(); + case Kind::kMap: + return static_cast(this)->aliases(); + default: + // Everything else does not support aliases. + return absl::Span(); + } +} + std::string Type::DebugString() const { switch (kind()) { case Kind::kNullType: diff --git a/base/type.h b/base/type.h index 764e9cd9a..395b09d1d 100644 --- a/base/type.h +++ b/base/type.h @@ -24,6 +24,7 @@ #include "absl/base/optimization.h" #include "absl/hash/hash.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "base/handle.h" #include "base/internal/data.h" #include "base/internal/type.h" // IWYU pragma: export @@ -90,6 +91,7 @@ class Type : public base_internal::Data { } private: + friend class TypeManager; friend class EnumType; friend class StructType; friend class ListType; @@ -100,6 +102,11 @@ class Type : public base_internal::Data { friend class base_internal::TypeHandle; friend class OpaqueType; + // This is used by TypeManager to determine whether a type has any known + // aliases. This is currently only used for JSON-like types. Pretend this + // doesn't exist. + absl::Span aliases() const; + static bool Equals(const Type& lhs, const Type& rhs, Kind kind); static bool Equals(const Type& lhs, const Type& rhs) { diff --git a/base/type_manager.cc b/base/type_manager.cc index fb2fd8b77..9bbbf9598 100644 --- a/base/type_manager.cc +++ b/base/type_manager.cc @@ -17,38 +17,79 @@ #include #include "absl/base/macros.h" +#include "absl/base/optimization.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" +#include "absl/types/span.h" #include "internal/status_macros.h" namespace cel { absl::StatusOr>> TypeManager::ResolveType( absl::string_view name) { - { - // Check for builtin types first. - CEL_ASSIGN_OR_RETURN( - auto type, TypeProvider::Builtin().ProvideType(type_factory(), name)); - if (type) { - return type; - } - } - // Check with the type registry. + // Check the cached types. { absl::ReaderMutexLock lock(&mutex_); auto existing = types_.find(name); - if (existing != types_.end()) { + if (ABSL_PREDICT_TRUE(existing != types_.end())) { return existing->second; } } + // Check for builtin types. + TypeProvider& builtin_type_provider = TypeProvider::Builtin(); + { + CEL_ASSIGN_OR_RETURN( + auto type, builtin_type_provider.ProvideType(type_factory(), name)); + if (type) { + absl::string_view provided_name = (*type)->name(); + // We do not check that `provided_name` matches name for the builtin type + // provider. There are some special types that have aliases. + return CacheTypeWithAliases(provided_name, std::move(type).value()); + } + } + if (ABSL_PREDICT_FALSE(&builtin_type_provider == &type_provider())) { + return absl::nullopt; + } // Delegate to TypeRegistry implementation. CEL_ASSIGN_OR_RETURN(auto type, type_provider().ProvideType(type_factory(), name)); - if (!type) { + if (ABSL_PREDICT_FALSE(!type)) { return absl::nullopt; } - ABSL_ASSERT(name == (*type)->name()); + absl::string_view provided_name = (*type)->name(); + if (ABSL_PREDICT_FALSE(name != provided_name)) { + return absl::InternalError( + absl::StrCat("TypeProvider provided ", provided_name, " for ", name)); + } + return CacheType(provided_name, std::move(type).value()); +} + +Handle TypeManager::CacheType(absl::string_view name, + Handle&& type) { + ABSL_ASSERT(name == type->name()); + absl::WriterMutexLock lock(&mutex_); + return types_.insert({name, std::move(type)}).first->second; +} + +Handle TypeManager::CacheTypeWithAliases(absl::string_view name, + Handle&& type) { + absl::Span aliases = type->aliases(); + if (aliases.empty()) { + return CacheType(name, std::move(type)); + } absl::WriterMutexLock lock(&mutex_); - return types_.insert({(*type)->name(), std::move(*type)}).first->second; + auto insertion = types_.insert({name, type}); + if (insertion.second) { + // Somebody beat us to caching. + return insertion.first->second; + } + for (const auto& alias : aliases) { + insertion = types_.insert({alias, type}); + ABSL_ASSERT(insertion.second); + } + return std::move(type); } } // namespace cel diff --git a/base/type_manager.h b/base/type_manager.h index da8c360dc..ccefeab42 100644 --- a/base/type_manager.h +++ b/base/type_manager.h @@ -15,8 +15,6 @@ #ifndef THIRD_PARTY_CEL_CPP_BASE_TYPE_MANAGER_H_ #define THIRD_PARTY_CEL_CPP_BASE_TYPE_MANAGER_H_ -#include - #include "absl/base/attributes.h" #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" @@ -52,6 +50,11 @@ class TypeManager final { absl::string_view name); private: + Handle CacheType(absl::string_view name, Handle&& type); + + Handle CacheTypeWithAliases(absl::string_view name, + Handle&& type); + TypeFactory& type_factory_; TypeProvider& type_provider_; diff --git a/base/types/dyn_type.cc b/base/types/dyn_type.cc index d3639054f..24a8b812d 100644 --- a/base/types/dyn_type.cc +++ b/base/types/dyn_type.cc @@ -14,8 +14,16 @@ #include "base/types/dyn_type.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" + namespace cel { CEL_INTERNAL_TYPE_IMPL(DynType); +absl::Span DynType::aliases() const { + // Currently google.protobuf.Value also resolves to dyn. + return absl::MakeConstSpan({absl::string_view("google.protobuf.Value")}); +} + } // namespace cel diff --git a/base/types/dyn_type.h b/base/types/dyn_type.h index 008e1e9bf..1dc509697 100644 --- a/base/types/dyn_type.h +++ b/base/types/dyn_type.h @@ -16,6 +16,8 @@ #define THIRD_PARTY_CEL_CPP_BASE_TYPES_DYN_TYPE_H_ #include "absl/log/absl_check.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "base/kind.h" #include "base/type.h" @@ -46,9 +48,13 @@ class DynType final : public base_internal::SimpleType { using Base::DebugString; private: + friend class Type; friend class base_internal::LegacyListType; friend class base_internal::LegacyMapType; + // See Type::aliases(). + absl::Span aliases() const; + CEL_INTERNAL_SIMPLE_TYPE_MEMBERS(DynType, DynValue); }; diff --git a/base/types/list_type.cc b/base/types/list_type.cc index 190a26b18..cda767bf7 100644 --- a/base/types/list_type.cc +++ b/base/types/list_type.cc @@ -19,6 +19,8 @@ #include "absl/base/macros.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "base/internal/data.h" #include "base/types/dyn_type.h" @@ -26,6 +28,15 @@ namespace cel { CEL_INTERNAL_TYPE_IMPL(ListType); +absl::Span ListType::aliases() const { + if (element()->kind() == Kind::kDyn) { + // Currently google.protobuf.ListValue resolves to list. + return absl::MakeConstSpan( + {absl::string_view("google.protobuf.ListValue")}); + } + return absl::Span(); +} + std::string ListType::DebugString() const { return absl::StrCat(name(), "(", element()->DebugString(), ")"); } diff --git a/base/types/list_type.h b/base/types/list_type.h index 879fb3fad..2e7e11e1c 100644 --- a/base/types/list_type.h +++ b/base/types/list_type.h @@ -21,6 +21,7 @@ #include "absl/log/absl_check.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "base/internal/data.h" #include "base/kind.h" #include "base/type.h" @@ -55,12 +56,16 @@ class ListType : public Type { } private: + friend class Type; friend class MemoryManager; friend class TypeFactory; friend class base_internal::TypeHandle; friend class base_internal::LegacyListType; friend class base_internal::ModernListType; + // See Type::aliases(). + absl::Span aliases() const; + ListType() = default; }; diff --git a/base/types/map_type.cc b/base/types/map_type.cc index 518d2a783..82bfac3d2 100644 --- a/base/types/map_type.cc +++ b/base/types/map_type.cc @@ -19,6 +19,8 @@ #include "absl/base/macros.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "base/internal/data.h" #include "base/types/dyn_type.h" @@ -26,6 +28,14 @@ namespace cel { CEL_INTERNAL_TYPE_IMPL(MapType); +absl::Span MapType::aliases() const { + if (key()->kind() == Kind::kString && value()->kind() == Kind::kDyn) { + // Currently google.protobuf.Struct resolves to map. + return absl::MakeConstSpan({absl::string_view("google.protobuf.Struct")}); + } + return absl::Span(); +} + std::string MapType::DebugString() const { return absl::StrCat(name(), "(", key()->DebugString(), ", ", value()->DebugString(), ")"); diff --git a/base/types/map_type.h b/base/types/map_type.h index 807f2303c..57df887b4 100644 --- a/base/types/map_type.h +++ b/base/types/map_type.h @@ -21,6 +21,7 @@ #include "absl/log/absl_check.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "base/internal/data.h" #include "base/kind.h" #include "base/type.h" @@ -59,12 +60,16 @@ class MapType : public Type { const Handle& value() const; private: + friend class Type; friend class MemoryManager; friend class TypeFactory; friend class base_internal::TypeHandle; friend class base_internal::LegacyMapType; friend class base_internal::ModernMapType; + // See Type::aliases(). + absl::Span aliases() const; + MapType() = default; }; From 8751ac338f6b65bd647f32bd632b9efd9266c9df Mon Sep 17 00:00:00 2001 From: jcking Date: Mon, 15 May 2023 17:35:07 +0000 Subject: [PATCH 268/303] Remove `final` from `cel::Allocator` PiperOrigin-RevId: 532156138 --- base/memory.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/base/memory.h b/base/memory.h index 9fde6cb4a..835a0bf06 100644 --- a/base/memory.h +++ b/base/memory.h @@ -260,7 +260,7 @@ class ArenaMemoryManager : public MemoryManager { // STL allocator implementation which is backed by MemoryManager. template -class Allocator final { +class Allocator { public: using value_type = T; using pointer = T*; From b44791e1d6fbc66381cdac8253ff3a21bad8184f Mon Sep 17 00:00:00 2001 From: jcking Date: Mon, 15 May 2023 17:51:47 +0000 Subject: [PATCH 269/303] Use `struct` consistently with `cel::base_internal::AnyData` to silence warnings PiperOrigin-RevId: 532161387 --- base/type.h | 2 +- base/types/list_type.h | 2 +- base/types/map_type.h | 2 +- base/types/struct_type.h | 2 +- base/types/wrapper_type.h | 12 ++++++------ base/value.h | 2 +- base/values/bytes_value.h | 4 ++-- base/values/enum_value.h | 2 +- base/values/error_value.h | 2 +- base/values/list_value.h | 2 +- base/values/map_value.h | 2 +- base/values/string_value.h | 4 ++-- base/values/struct_value.h | 2 +- base/values/type_value.h | 6 +++--- base/values/unknown_value.h | 2 +- 15 files changed, 24 insertions(+), 24 deletions(-) diff --git a/base/type.h b/base/type.h index 395b09d1d..caee5e043 100644 --- a/base/type.h +++ b/base/type.h @@ -375,7 +375,7 @@ CEL_INTERNAL_TYPE_DECL(Type); template \ friend class base_internal::SimpleValue; \ template \ - friend class base_internal::AnyData; \ + friend struct base_internal::AnyData; \ \ ABSL_ATTRIBUTE_PURE_FUNCTION static const Handle& Get(); \ \ diff --git a/base/types/list_type.h b/base/types/list_type.h index 2e7e11e1c..cea6c2870 100644 --- a/base/types/list_type.h +++ b/base/types/list_type.h @@ -87,7 +87,7 @@ class LegacyListType final : public ListType, public InlineData { friend class cel::ListType; friend class base_internal::TypeHandle; template - friend class AnyData; + friend struct AnyData; static constexpr uintptr_t kMetadata = base_internal::kStoredInline | base_internal::kTrivial | diff --git a/base/types/map_type.h b/base/types/map_type.h index 57df887b4..792661ed3 100644 --- a/base/types/map_type.h +++ b/base/types/map_type.h @@ -92,7 +92,7 @@ class LegacyMapType final : public MapType, public InlineData { friend class cel::MapType; friend class base_internal::TypeHandle; template - friend class AnyData; + friend struct AnyData; static constexpr uintptr_t kMetadata = base_internal::kStoredInline | base_internal::kTrivial | diff --git a/base/types/struct_type.h b/base/types/struct_type.h index d7ce4bca9..3b3a0358b 100644 --- a/base/types/struct_type.h +++ b/base/types/struct_type.h @@ -265,7 +265,7 @@ class LegacyStructType final : public StructType, public InlineData { friend class cel::StructType; friend class LegacyStructValue; template - friend class AnyData; + friend struct AnyData; explicit LegacyStructType(uintptr_t msg) : StructType(), InlineData(kMetadata), msg_(msg) {} diff --git a/base/types/wrapper_type.h b/base/types/wrapper_type.h index 7a73b3cff..9de378010 100644 --- a/base/types/wrapper_type.h +++ b/base/types/wrapper_type.h @@ -117,7 +117,7 @@ class BoolWrapperType final : public WrapperType { private: friend class TypeFactory; template - friend class base_internal::AnyData; + friend struct base_internal::AnyData; ABSL_ATTRIBUTE_PURE_FUNCTION static const Handle& Get(); @@ -154,7 +154,7 @@ class BytesWrapperType final : public WrapperType { private: friend class TypeFactory; template - friend class base_internal::AnyData; + friend struct base_internal::AnyData; ABSL_ATTRIBUTE_PURE_FUNCTION static const Handle& Get(); @@ -191,7 +191,7 @@ class DoubleWrapperType final : public WrapperType { private: friend class TypeFactory; template - friend class base_internal::AnyData; + friend struct base_internal::AnyData; ABSL_ATTRIBUTE_PURE_FUNCTION static const Handle& Get(); @@ -228,7 +228,7 @@ class IntWrapperType final : public WrapperType { private: friend class TypeFactory; template - friend class base_internal::AnyData; + friend struct base_internal::AnyData; ABSL_ATTRIBUTE_PURE_FUNCTION static const Handle& Get(); @@ -265,7 +265,7 @@ class StringWrapperType final : public WrapperType { private: friend class TypeFactory; template - friend class base_internal::AnyData; + friend struct base_internal::AnyData; ABSL_ATTRIBUTE_PURE_FUNCTION static const Handle& Get(); @@ -302,7 +302,7 @@ class UintWrapperType final : public WrapperType { private: friend class TypeFactory; template - friend class base_internal::AnyData; + friend struct base_internal::AnyData; ABSL_ATTRIBUTE_PURE_FUNCTION static const Handle& Get(); diff --git a/base/value.h b/base/value.h index 4af859689..21ba4ea5b 100644 --- a/base/value.h +++ b/base/value.h @@ -323,7 +323,7 @@ CEL_INTERNAL_VALUE_DECL(Value); friend class ValueFactory; \ friend class base_internal::ValueHandle; \ template \ - friend class base_internal::AnyData; \ + friend struct base_internal::AnyData; \ \ value_class() = default; \ value_class(const value_class&) = default; \ diff --git a/base/values/bytes_value.h b/base/values/bytes_value.h index 728e56ba1..d90311980 100644 --- a/base/values/bytes_value.h +++ b/base/values/bytes_value.h @@ -121,7 +121,7 @@ class InlinedCordBytesValue final : public BytesValue, public InlineData { private: friend class BytesValue; template - friend class AnyData; + friend struct AnyData; static constexpr uintptr_t kMetadata = kStoredInline | AsInlineVariant(InlinedBytesValueVariant::kCord) | @@ -144,7 +144,7 @@ class InlinedStringViewBytesValue final : public BytesValue, public InlineData { private: friend class BytesValue; template - friend class AnyData; + friend struct AnyData; static constexpr uintptr_t kMetadata = kStoredInline | (static_cast(kKind) << kKindShift); diff --git a/base/values/enum_value.h b/base/values/enum_value.h index 19e2fd753..8b10cdbba 100644 --- a/base/values/enum_value.h +++ b/base/values/enum_value.h @@ -68,7 +68,7 @@ class EnumValue final : public Value, public base_internal::InlineData { private: friend class base_internal::ValueHandle; template - friend class base_internal::AnyData; + friend struct base_internal::AnyData; static constexpr uintptr_t kMetadata = base_internal::kStoredInline | diff --git a/base/values/error_value.h b/base/values/error_value.h index 3c8c06196..e7b1a8cef 100644 --- a/base/values/error_value.h +++ b/base/values/error_value.h @@ -58,7 +58,7 @@ class ErrorValue final : public Value, public base_internal::InlineData { private: friend class ValueHandle; template - friend class base_internal::AnyData; + friend struct base_internal::AnyData; friend struct interop_internal::ErrorValueAccess; static constexpr uintptr_t kMetadata = diff --git a/base/values/list_value.h b/base/values/list_value.h index 513757036..9322351f9 100644 --- a/base/values/list_value.h +++ b/base/values/list_value.h @@ -180,7 +180,7 @@ class LegacyListValue final : public ListValue, public InlineData { friend class base_internal::ValueHandle; friend class cel::ListValue; template - friend class AnyData; + friend struct AnyData; static constexpr uintptr_t kMetadata = kStoredInline | kTrivial | (static_cast(kKind) << kKindShift); diff --git a/base/values/map_value.h b/base/values/map_value.h index 2f7477385..ca75538ac 100644 --- a/base/values/map_value.h +++ b/base/values/map_value.h @@ -210,7 +210,7 @@ class LegacyMapValue final : public MapValue, public InlineData { friend class base_internal::ValueHandle; friend class cel::MapValue; template - friend class AnyData; + friend struct AnyData; static constexpr uintptr_t kMetadata = kStoredInline | kTrivial | (static_cast(kKind) << kKindShift); diff --git a/base/values/string_value.h b/base/values/string_value.h index 058dad924..426f3b44a 100644 --- a/base/values/string_value.h +++ b/base/values/string_value.h @@ -136,7 +136,7 @@ class InlinedCordStringValue final : public StringValue, public InlineData { friend class StringValue; friend class ValueFactory; template - friend class AnyData; + friend struct AnyData; static constexpr uintptr_t kMetadata = kStoredInline | AsInlineVariant(InlinedStringValueVariant::kCord) | @@ -160,7 +160,7 @@ class InlinedStringViewStringValue final : public StringValue, private: friend class StringValue; template - friend class AnyData; + friend struct AnyData; static constexpr uintptr_t kMetadata = kStoredInline | (static_cast(kKind) << kKindShift); diff --git a/base/values/struct_value.h b/base/values/struct_value.h index af3ba91bb..1e42f68c0 100644 --- a/base/values/struct_value.h +++ b/base/values/struct_value.h @@ -262,7 +262,7 @@ class LegacyStructValue final : public StructValue, public InlineData { friend class base_internal::ValueHandle; friend class cel::StructValue; template - friend class AnyData; + friend struct AnyData; friend struct interop_internal::LegacyStructValueAccess; static constexpr uintptr_t kMetadata = diff --git a/base/values/type_value.h b/base/values/type_value.h index a9ef6a7b5..09bf3f398 100644 --- a/base/values/type_value.h +++ b/base/values/type_value.h @@ -56,7 +56,7 @@ class TypeValue : public Value { private: friend class ValueHandle; template - friend class base_internal::AnyData; + friend struct base_internal::AnyData; friend class base_internal::LegacyTypeValue; friend class base_internal::ModernTypeValue; @@ -78,7 +78,7 @@ class LegacyTypeValue final : public TypeValue, InlineData { private: friend class ValueHandle; template - friend class base_internal::AnyData; + friend struct base_internal::AnyData; static constexpr uintptr_t kMetadata = kStoredInline | kTrivial | (static_cast(kKind) << kKindShift); @@ -103,7 +103,7 @@ class ModernTypeValue final : public TypeValue, InlineData { private: friend class ValueHandle; template - friend class base_internal::AnyData; + friend struct base_internal::AnyData; static constexpr uintptr_t kMetadata = kStoredInline | (static_cast(kKind) << kKindShift); diff --git a/base/values/unknown_value.h b/base/values/unknown_value.h index f57965b8a..8f559a89e 100644 --- a/base/values/unknown_value.h +++ b/base/values/unknown_value.h @@ -54,7 +54,7 @@ class UnknownValue final : public Value, public base_internal::InlineData { private: friend class ValueHandle; template - friend class base_internal::AnyData; + friend struct base_internal::AnyData; friend struct interop_internal::UnknownValueAccess; static constexpr uintptr_t kMetadata = From 74dea69078bc1b1bad31767bd5ff4d5c845afa57 Mon Sep 17 00:00:00 2001 From: jcking Date: Mon, 15 May 2023 19:34:08 +0000 Subject: [PATCH 270/303] Implement initial support for `google.protobuf.Any` PiperOrigin-RevId: 532191709 --- base/internal/memory_manager_testing.h | 5 + base/type_provider.cc | 7 +- base/type_provider_test.cc | 7 + extensions/protobuf/BUILD | 1 + extensions/protobuf/internal/BUILD | 18 ++ extensions/protobuf/internal/descriptors.cc | 130 +++++++++++ extensions/protobuf/internal/descriptors.h | 61 +++++ extensions/protobuf/struct_type.h | 2 + extensions/protobuf/struct_value.cc | 68 +++++- extensions/protobuf/struct_value_test.cc | 42 ++++ extensions/protobuf/type.cc | 10 +- extensions/protobuf/type.h | 10 +- extensions/protobuf/value.cc | 237 ++++++++++++++++++++ extensions/protobuf/value.h | 21 +- extensions/protobuf/value_test.cc | 210 +++++++++++++++++ 15 files changed, 813 insertions(+), 16 deletions(-) create mode 100644 extensions/protobuf/internal/descriptors.cc create mode 100644 extensions/protobuf/internal/descriptors.h diff --git a/base/internal/memory_manager_testing.h b/base/internal/memory_manager_testing.h index e62e11853..946660fec 100644 --- a/base/internal/memory_manager_testing.h +++ b/base/internal/memory_manager_testing.h @@ -29,6 +29,11 @@ enum class MemoryManagerTestMode { std::string MemoryManagerTestModeToString(MemoryManagerTestMode mode); +template +void AbslStringify(S& sink, MemoryManagerTestMode mode) { + sink.Append(MemoryManagerTestModeToString(mode)); +} + inline auto MemoryManagerTestModeAll() { return testing::Values(MemoryManagerTestMode::kGlobal, MemoryManagerTestMode::kArena); diff --git a/base/type_provider.cc b/base/type_provider.cc index c60e73980..0c9b79b5e 100644 --- a/base/type_provider.cc +++ b/base/type_provider.cc @@ -50,6 +50,7 @@ class BuiltinTypeProvider final : public TypeProvider { {"google.protobuf.Struct", GetStructType}, {"type", GetTypeType}, {"google.protobuf.Value", GetValueType}, + {"google.protobuf.Any", GetAnyType}, {"google.protobuf.BoolValue", GetBoolWrapperType}, {"google.protobuf.BytesValue", GetBytesWrapperType}, {"google.protobuf.DoubleValue", GetDoubleWrapperType}, @@ -172,7 +173,11 @@ class BuiltinTypeProvider final : public TypeProvider { return type_factory.GetDynType(); } - std::array types_; + static absl::StatusOr> GetAnyType(TypeFactory& type_factory) { + return type_factory.GetAnyType(); + } + + std::array types_; }; } // namespace diff --git a/base/type_provider_test.cc b/base/type_provider_test.cc index 3e99dc390..a65139779 100644 --- a/base/type_provider_test.cc +++ b/base/type_provider_test.cc @@ -135,6 +135,13 @@ TEST_P(BuiltinTypeProviderTest, ProvidesStructWrapperType) { EXPECT_TRUE(struct_type->As()->value()->Is()); } +TEST_P(BuiltinTypeProviderTest, ProvidesAnyType) { + TypeFactory type_factory(memory_manager()); + ASSERT_THAT( + TypeProvider::Builtin().ProvideType(type_factory, "google.protobuf.Any"), + IsOkAndHolds(Optional(Eq(type_factory.GetAnyType())))); +} + TEST_P(BuiltinTypeProviderTest, DoesNotProvide) { TypeFactory type_factory(memory_manager()); ASSERT_THAT( diff --git a/extensions/protobuf/BUILD b/extensions/protobuf/BUILD index fcd9117c7..2cda3c67d 100644 --- a/extensions/protobuf/BUILD +++ b/extensions/protobuf/BUILD @@ -192,6 +192,7 @@ cc_test( "//base:value", "//base/internal:memory_manager_testing", "//base/testing:value_matchers", + "//extensions/protobuf/internal:descriptors", "//extensions/protobuf/internal:testing", "//internal:testing", "@com_google_absl//absl/functional:function_ref", diff --git a/extensions/protobuf/internal/BUILD b/extensions/protobuf/internal/BUILD index 475195c52..76a83e684 100644 --- a/extensions/protobuf/internal/BUILD +++ b/extensions/protobuf/internal/BUILD @@ -19,6 +19,24 @@ package( licenses(["notice"]) +cc_library( + name = "descriptors", + testonly = True, + srcs = ["descriptors.cc"], + hdrs = ["descriptors.h"], + deps = [ + "//base:memory", + "//base:type", + "//extensions/protobuf:memory_manager", + "//extensions/protobuf:type", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/log:absl_check", + "@com_google_protobuf//:protobuf", + ], +) + cc_library( name = "map_reflection", srcs = ["map_reflection.cc"], diff --git a/extensions/protobuf/internal/descriptors.cc b/extensions/protobuf/internal/descriptors.cc new file mode 100644 index 000000000..d974e7a80 --- /dev/null +++ b/extensions/protobuf/internal/descriptors.cc @@ -0,0 +1,130 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/protobuf/internal/descriptors.h" + +#include +#include + +#include "google/protobuf/descriptor.pb.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/absl_check.h" +#include "extensions/protobuf/memory_manager.h" +#include "extensions/protobuf/type_provider.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/descriptor_database.h" + +namespace cel::extensions::protobuf_internal { + +namespace { + +class DescriptorGathererImpl final : public DescriptorGatherer { + public: + DescriptorGathererImpl() = default; + + void Gather(const google::protobuf::Descriptor& descriptor) override { + GatherFile(*descriptor.file()); + } + + std::unique_ptr Finish() override { + auto database = std::make_unique(); + for (auto& file : files_) { + ABSL_CHECK(database->AddAndOwn(file.second.release())); // Crash OK + } + visited_.clear(); + files_.clear(); + return database; + } + + private: + void GatherFile(const google::protobuf::FileDescriptor& descriptor) { + if (!Visit(descriptor.name())) { + return; + } + descriptor.CopyTo(&File(descriptor)); + int dependency_count = descriptor.dependency_count(); + for (int dependency_index = 0; dependency_index < dependency_count; + ++dependency_index) { + GatherFile(*descriptor.dependency(dependency_index)); + } + } + + bool Visit(absl::string_view name) { return visited_.insert(name).second; } + + google::protobuf::FileDescriptorProto& File(const google::protobuf::FileDescriptor& descriptor) { + return File(descriptor.name()); + } + + google::protobuf::FileDescriptorProto& File(absl::string_view name) { + auto& file = files_[name]; + if (file == nullptr) { + file = std::make_unique(); + } + file->set_name(name); + return *file; + } + + absl::flat_hash_set visited_; + absl::flat_hash_map> + files_; +}; + +} // namespace + +std::unique_ptr NewDescriptorGatherer() { + return std::make_unique(); +} + +void WithCustomDescriptorPool( + MemoryManager& memory_manager, const google::protobuf::Message& message, + const google::protobuf::Descriptor& additional_descriptor, + absl::FunctionRef invocable) { + std::unique_ptr database; + { + auto gatherer = NewDescriptorGatherer(); + gatherer->Gather(*message.GetDescriptor()); + gatherer->Gather(additional_descriptor); + database = gatherer->Finish(); + } + google::protobuf::DescriptorPool pool(database.get()); + google::protobuf::DynamicMessageFactory message_factory(&pool); + message_factory.SetDelegateToGeneratedFactory(false); + const auto* descriptor = + pool.FindMessageTypeByName(message.GetDescriptor()->full_name()); + ABSL_CHECK(descriptor != nullptr) // Crash OK + << "Unable to get descriptor for " + << message.GetDescriptor()->full_name(); + const auto* prototype = message_factory.GetPrototype(descriptor); + ABSL_CHECK(prototype != nullptr) // Crash OK + << "Unable to get prototype for " << descriptor->full_name(); + google::protobuf::Arena* arena = nullptr; + if (ProtoMemoryManager::Is(memory_manager)) { + arena = ProtoMemoryManager::CastToProtoArena(memory_manager); + } + auto* custom = prototype->New(arena); + { + absl::Cord serialized; + ABSL_CHECK(message.SerializePartialToCord(&serialized)); + ABSL_CHECK(custom->ParsePartialFromCord(serialized)); + } + ProtoTypeProvider type_provider(&pool, &message_factory); + invocable(type_provider, *custom); + if (arena == nullptr) { + delete custom; + } +} + +} // namespace cel::extensions::protobuf_internal diff --git a/extensions/protobuf/internal/descriptors.h b/extensions/protobuf/internal/descriptors.h new file mode 100644 index 000000000..777c06eff --- /dev/null +++ b/extensions/protobuf/internal/descriptors.h @@ -0,0 +1,61 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_INTERNAL_DESCRIPTORS_H_ +#define THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_INTERNAL_DESCRIPTORS_H_ + +#include + +#include "absl/functional/function_ref.h" +#include "base/memory.h" +#include "base/type_provider.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/descriptor_database.h" +#include "google/protobuf/dynamic_message.h" +#include "google/protobuf/message.h" + +namespace cel::extensions::protobuf_internal { + +// Interface capable of collecting `google::protobuf::FileDescriptorProto` relevant to the +// provided `google::protobuf::Descriptor` and creating a `google::protobuf::DescriptorDatabase`. +class DescriptorGatherer { + public: + virtual ~DescriptorGatherer() = default; + + virtual void Gather(const google::protobuf::Descriptor& descriptor) = 0; + + virtual std::unique_ptr Finish() = 0; +}; + +std::unique_ptr NewDescriptorGatherer(); + +// Converts a `google::protobuf::Message` which is a generated message into the equivalent +// dynamic message. This is done by copying all the relevant descriptors into a +// custom descriptor database and creating a custom descriptor pool and message +// factory. +void WithCustomDescriptorPool( + MemoryManager& memory_manager, const google::protobuf::Message& message, + const google::protobuf::Descriptor& additional_descriptor, + absl::FunctionRef invocable); +inline void WithCustomDescriptorPool( + MemoryManager& memory_manager, const google::protobuf::Message& message, + absl::FunctionRef invocable) { + WithCustomDescriptorPool(memory_manager, message, *message.GetDescriptor(), + invocable); +} + +} // namespace cel::extensions::protobuf_internal + +#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_INTERNAL_DESCRIPTORS_H_ diff --git a/extensions/protobuf/struct_type.h b/extensions/protobuf/struct_type.h index 852e77836..a0257a5dd 100644 --- a/extensions/protobuf/struct_type.h +++ b/extensions/protobuf/struct_type.h @@ -33,6 +33,7 @@ namespace cel::extensions { class ProtoTypeProvider; class ProtoStructValue; class ProtoType; +class ProtoValue; namespace protobuf_internal { class ParsedProtoStructValue; } @@ -75,6 +76,7 @@ class ProtoStructType final : public CEL_STRUCT_TYPE_CLASS { private: friend class ProtoStructTypeFieldIterator; friend class ProtoType; + friend class ProtoValue; friend class ProtoTypeProvider; friend class ProtoStructValue; friend class protobuf_internal::ParsedProtoStructValue; diff --git a/extensions/protobuf/struct_value.cc b/extensions/protobuf/struct_value.cc index 8ead55cc2..a2f6623e5 100644 --- a/extensions/protobuf/struct_value.cc +++ b/extensions/protobuf/struct_value.cc @@ -60,6 +60,7 @@ #include "extensions/protobuf/memory_manager.h" #include "extensions/protobuf/struct_type.h" #include "extensions/protobuf/type.h" +#include "extensions/protobuf/value.h" #include "internal/status_macros.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/map_field.h" @@ -148,10 +149,7 @@ std::unique_ptr ProtoStructValue::value( } std::unique_ptr ProtoStructValue::value() const { - return absl::WrapUnique( - ValuePointer(*ABSL_DIE_IF_NULL( // Crash OK - google::protobuf::MessageFactory::generated_factory()), - nullptr)); + return absl::WrapUnique(ValuePointer(*type()->factory_, nullptr)); } google::protobuf::Message* ProtoStructValue::value( @@ -160,9 +158,7 @@ google::protobuf::Message* ProtoStructValue::value( } google::protobuf::Message* ProtoStructValue::value(google::protobuf::Arena& arena) const { - return ValuePointer(*ABSL_DIE_IF_NULL( // Crash OK - google::protobuf::MessageFactory::generated_factory()), - &arena); + return ValuePointer(*type()->factory_, &arena); } namespace { @@ -933,6 +929,50 @@ class ParsedProtoListValue const google::protobuf::RepeatedFieldRef fields_; }; +// repeated google.protobuf.Any +template <> +class ParsedProtoListValue + : public CEL_LIST_VALUE_CLASS { + public: + ParsedProtoListValue(Handle type, + google::protobuf::RepeatedFieldRef fields) + : CEL_LIST_VALUE_CLASS(std::move(type)), fields_(std::move(fields)) {} + + std::string DebugString() const final { + std::string out; + out.push_back('['); + auto field = fields_.begin(); + if (field != fields_.end()) { + ProtoDebugStringStruct(out, *field); + ++field; + for (; field != fields_.end(); ++field) { + out.append(", "); + ProtoDebugStringStruct(out, *field); + } + } + out.push_back(']'); + return out; + } + + size_t size() const final { return fields_.size(); } + + bool empty() const final { return fields_.empty(); } + + absl::StatusOr> Get(const GetContext& context, + size_t index) const final { + std::unique_ptr scratch(fields_.NewMessage()); + const auto& field = fields_.Get(static_cast(index), scratch.get()); + return ProtoValue::Create(context.value_factory(), field); + } + + private: + cel::internal::TypeInfo TypeId() const final { + return internal::TypeId>(); + } + + const google::protobuf::RepeatedFieldRef fields_; +}; + // repeated google.protobuf.BoolValue template <> class ParsedProtoListValue @@ -1577,6 +1617,9 @@ class ParsedProtoMapValue : public CEL_MAP_VALUE_CLASS { return protobuf_internal::CreateBorrowedValue( owner_from_this(), context.value_factory(), proto_value.GetMessageValue()); + case Kind::kAny: + return ProtoValue::Create(context.value_factory(), + proto_value.GetMessageValue()); case Kind::kBool: { // google.protobuf.BoolValue, mapped to CEL primitive bool type for // map values. @@ -2354,6 +2397,13 @@ absl::StatusOr> ParsedProtoStructValue::GetRepeatedField( owner_from_this(), field.type.As(), reflect.GetRepeatedFieldRef(value(), &field_desc)); + case Kind::kAny: + return context.value_factory() + .CreateBorrowedListValue< + ParsedProtoListValue>( + owner_from_this(), field.type.As(), + reflect.GetRepeatedFieldRef(value(), + &field_desc)); case Kind::kBool: // google.protobuf.BoolValue, mapped to CEL primitive bool type for // list elements. @@ -2525,6 +2575,10 @@ absl::StatusOr> ParsedProtoStructValue::GetSingularField( return protobuf_internal::CreateBorrowedValue( owner_from_this(), context.value_factory(), reflect.GetMessage(value(), &field_desc)); + case Kind::kAny: + // google.protobuf.Any + return ProtoValue::Create(context.value_factory(), + reflect.GetMessage(value(), &field_desc)); case Kind::kWrapper: { if (context.unbox_null_wrapper_types() && !reflect.HasField(value(), &field_desc)) { diff --git a/extensions/protobuf/struct_value_test.cc b/extensions/protobuf/struct_value_test.cc index 3588fc59e..e83beded4 100644 --- a/extensions/protobuf/struct_value_test.cc +++ b/extensions/protobuf/struct_value_test.cc @@ -311,6 +311,11 @@ TEST_P(ProtoStructValueTest, ValueHasField) { [](TestAllTypes& message) { message.mutable_single_value(); }); } +TEST_P(ProtoStructValueTest, AnyHasField) { + TEST_HAS_FIELD(memory_manager(), "single_any", + [](TestAllTypes& message) { message.mutable_single_any(); }); +} + TEST_P(ProtoStructValueTest, NullValueListHasField) { TEST_HAS_FIELD(memory_manager(), "repeated_null_value", [](TestAllTypes& message) { @@ -466,6 +471,11 @@ TEST_P(ProtoStructValueTest, ValueListHasField) { [](TestAllTypes& message) { message.add_repeated_value(); }); } +TEST_P(ProtoStructValueTest, AnyListHasField) { + TEST_HAS_FIELD(memory_manager(), "repeated_any", + [](TestAllTypes& message) { message.add_repeated_any(); }); +} + void TestGetFieldImpl( MemoryManager& memory_manager, absl::FunctionRef>( @@ -1017,6 +1027,20 @@ TEST_P(ProtoStructValueTest, ValueGetField) { }); } +TEST_P(ProtoStructValueTest, AnyGetField) { + TEST_GET_FIELD( + memory_manager(), "single_any", + [](const Handle& field) { EXPECT_TRUE(field->Is()); }, + [](TestAllTypes& message) { + google::protobuf::BoolValue proto; + proto.set_value(true); + ASSERT_TRUE(message.mutable_single_any()->PackFrom(proto)); + }, + [](const Handle& field) { + EXPECT_TRUE(field->As().value()); + }); +} + void TestGetListFieldImpl( MemoryManager& memory_manager, absl::FunctionRef>( @@ -1598,6 +1622,24 @@ TEST_P(ProtoStructValueTest, StringWrapperListGetField) { }); } +TEST_P(ProtoStructValueTest, AnyListGetField) { + TEST_GET_LIST_FIELD( + memory_manager(), "repeated_any", EmptyListFieldTester, + [](TestAllTypes& message) { + google::protobuf::BoolValue proto; + proto.set_value(true); + ASSERT_TRUE(message.add_repeated_any()->PackFrom(proto)); + }, + [](ValueFactory& value_factory, const Handle& field) { + EXPECT_EQ(field->size(), 1); + EXPECT_FALSE(field->empty()); + ASSERT_OK_AND_ASSIGN( + auto field_value, + field->Get(ListValue::GetContext(value_factory), 0)); + EXPECT_TRUE(field_value.As()->value()); + }); +} + template void TestMapHasField(MemoryManager& memory_manager, absl::string_view map_field_name, diff --git a/extensions/protobuf/type.cc b/extensions/protobuf/type.cc index 342cc43ed..084042052 100644 --- a/extensions/protobuf/type.cc +++ b/extensions/protobuf/type.cc @@ -63,11 +63,11 @@ absl::StatusOr> ProtoType::Resolve( absl::StrCat("Missing protocol buffer type implementation for \"", descriptor.full_name(), "\"")); } - if (ABSL_PREDICT_FALSE(!(*type)->Is() && - !(*type)->Is() && - !(*type)->Is() && - !(*type)->Is() && !IsJsonList(**type) && - !IsJsonMap(**type) && !(*type)->Is())) { + if (ABSL_PREDICT_FALSE( + !(*type)->Is() && !(*type)->Is() && + !(*type)->Is() && !(*type)->Is() && + !IsJsonList(**type) && !IsJsonMap(**type) && + !(*type)->Is() && !(*type)->Is())) { return absl::FailedPreconditionError( absl::StrCat("Unexpected protocol buffer type implementation for \"", descriptor.full_name(), "\": ", (*type)->DebugString())); diff --git a/extensions/protobuf/type.h b/extensions/protobuf/type.h index 2f87fca12..10a34d5eb 100644 --- a/extensions/protobuf/type.h +++ b/extensions/protobuf/type.h @@ -17,6 +17,7 @@ #include +#include "google/protobuf/any.pb.h" #include "google/protobuf/duration.pb.h" #include "google/protobuf/struct.pb.h" #include "google/protobuf/timestamp.pb.h" @@ -126,6 +127,12 @@ class ProtoType final { template using NotWrapperMessage = std::negation>; + template + using AnyMessage = std::is_same>; + + template + using NotAnyMessage = std::negation>; + public: // Resolve Type from a protocol buffer enum descriptor. static absl::StatusOr> Resolve( @@ -155,7 +162,8 @@ class ProtoType final { template static std::enable_if_t< std::conjunction_v, NotDurationMessage, - NotTimestampMessage, NotWrapperMessage>, + NotTimestampMessage, NotWrapperMessage, + NotAnyMessage>, absl::StatusOr>> Resolve(TypeManager& type_manager) { return ProtoStructType::Resolve(type_manager); diff --git a/extensions/protobuf/value.cc b/extensions/protobuf/value.cc index 8665a2f8a..ac2e6b2ec 100644 --- a/extensions/protobuf/value.cc +++ b/extensions/protobuf/value.cc @@ -21,7 +21,10 @@ #include #include +#include "google/protobuf/any.pb.h" #include "google/protobuf/struct.pb.h" +#include "google/protobuf/timestamp.pb.h" +#include "google/protobuf/descriptor.pb.h" #include "absl/base/attributes.h" #include "absl/base/call_once.h" #include "absl/base/macros.h" @@ -32,6 +35,7 @@ #include "absl/strings/cord.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "absl/strings/strip.h" #include "absl/time/time.h" #include "absl/types/variant.h" #include "base/handle.h" @@ -144,6 +148,8 @@ absl::StatusOr> CreateMemberJsonValue( ValueFactory& value_factory, const google::protobuf::Value& value, HandleFromThis&& owner_from_this) { switch (value.kind_case()) { + case google::protobuf::Value::KIND_NOT_SET: + ABSL_FALLTHROUGH_INTENDED; case google::protobuf::Value::kNullValue: return value_factory.GetNullValue(); case google::protobuf::Value::kBoolValue: @@ -1318,6 +1324,70 @@ absl::StatusOr> ValueMessageOwnConverter( } } +absl::StatusOr> AnyMessageCopyConverter( + ValueFactory& value_factory, const google::protobuf::Message& value) { + const auto* descriptor = value.GetDescriptor(); + if (descriptor == google::protobuf::Any::descriptor()) { + return ProtoValue::Create( + value_factory, + cel::internal::down_cast(value)); + } + const auto* reflect = value.GetReflection(); + if (ABSL_PREDICT_FALSE(reflect == nullptr)) { + return absl::InvalidArgumentError( + "reflection missing for google.protobuf.Any"); + } + const auto* type_url_field = + descriptor->FindFieldByNumber(google::protobuf::Any::kTypeUrlFieldNumber); + if (ABSL_PREDICT_FALSE(type_url_field == nullptr)) { + return absl::InvalidArgumentError( + "type_url field descriptor missing for google.protobuf.Any"); + } + if (ABSL_PREDICT_FALSE(type_url_field->is_repeated() || + type_url_field->is_map() || + type_url_field->cpp_type() != + google::protobuf::FieldDescriptor::CPPTYPE_STRING)) { + return absl::InvalidArgumentError( + "type_url field descriptor has unexpected type"); + } + const auto* value_field = + descriptor->FindFieldByNumber(google::protobuf::Any::kValueFieldNumber); + if (ABSL_PREDICT_FALSE(value_field == nullptr)) { + return absl::InvalidArgumentError( + "value field descriptor missing for google.protobuf.Any"); + } + if (ABSL_PREDICT_FALSE(value_field->is_repeated() || value_field->is_map() || + value_field->cpp_type() != + google::protobuf::FieldDescriptor::CPPTYPE_STRING)) { + return absl::InvalidArgumentError( + "value field descriptor has unexpected type"); + } + std::string type_url_storage; + absl::string_view type_url; + if (!type_url_field->is_extension() && + type_url_field->options().ctype() == google::protobuf::FieldOptions::CORD) { + type_url_storage = reflect->GetString(value, type_url_field); + type_url = type_url_storage; + } else { + type_url = reflect->GetStringView(value, type_url_field); + } + return ProtoValue::Create(value_factory, type_url, + reflect->GetCord(value, value_field)); +} + +absl::StatusOr> AnyMessageMoveConverter( + ValueFactory& value_factory, google::protobuf::Message&& value) { + // We currently do nothing special for moving. + return AnyMessageCopyConverter(value_factory, value); +} + +absl::StatusOr> AnyMessageBorrowConverter( + Owner& owner, ValueFactory& value_factory, + const google::protobuf::Message& value) { + // We currently do nothing special for borrowing. + return AnyMessageCopyConverter(value_factory, value); +} + ABSL_CONST_INIT absl::once_flag proto_value_once; ABSL_CONST_INIT DynamicMessageConverter dynamic_message_converters[] = { {"google.protobuf.Duration", DurationMessageCopyConverter, @@ -1348,6 +1418,8 @@ ABSL_CONST_INIT DynamicMessageConverter dynamic_message_converters[] = { ListValueMessageMoveConverter, ListValueMessageBorrowConverter}, {"google.protobuf.Value", ValueMessageCopyConverter, ValueMessageMoveConverter, ValueMessageBorrowConverter}, + {"google.protobuf.Any", AnyMessageCopyConverter, AnyMessageMoveConverter, + AnyMessageBorrowConverter}, }; DynamicMessageConverter* dynamic_message_converters_begin() { @@ -1467,6 +1539,171 @@ absl::StatusOr> ProtoValue::Create( } } +namespace { + +template +absl::StatusOr UnpackTo(const absl::Cord& cord) { + T proto; + if (ABSL_PREDICT_FALSE(!proto.ParseFromCord(cord))) { + return absl::InvalidArgumentError( + absl::StrCat("failed to unpack google.protobuf.Any as ", + T::descriptor()->full_name())); + } + return proto; +} + +} // namespace + +absl::StatusOr> ProtoValue::Create(ValueFactory& value_factory, + absl::string_view type_url, + const absl::Cord& payload) { + if (type_url.empty()) { + return value_factory.CreateErrorValue( + absl::UnknownError("invalid empty type URL in google.protobuf.Any")); + } + auto type_name = absl::StripPrefix(type_url, "type.googleapis.com/"); + CEL_ASSIGN_OR_RETURN(auto type, + value_factory.type_manager().ResolveType(type_name)); + if (ABSL_PREDICT_FALSE(!type.has_value())) { + return value_factory.CreateErrorValue( + absl::NotFoundError(absl::StrCat("type not found: ", type_url))); + } + switch ((*type)->kind()) { + case Kind::kAny: + ABSL_DCHECK(type_name == "google.protobuf.Any") << type_name; + // google.protobuf.Any + // + // We refuse google.protobuf.Any wrapped in google.protobuf.Any. + return absl::InvalidArgumentError( + "refusing to unpack google.protobuf.Any to google.protobuf.Any"); + case Kind::kStruct: { + if (!ProtoStructType::Is(**type)) { + return absl::FailedPreconditionError( + "google.protobuf.Any can only be unpacked to protocol " + "buffer message based structs"); + } + const auto& struct_type = (*type)->As(); + const auto* prototype = + struct_type.factory_->GetPrototype(struct_type.descriptor_); + if (ABSL_PREDICT_FALSE(prototype == nullptr)) { + return absl::InternalError(absl::StrCat( + "protocol buffer message factory does not have prototype for ", + struct_type.DebugString())); + } + auto proto = absl::WrapUnique(prototype->New()); + if (ABSL_PREDICT_FALSE(!proto->ParseFromCord(payload))) { + return absl::InvalidArgumentError( + absl::StrCat("failed to unpack google.protobuf.Any to ", + struct_type.DebugString())); + } + return ProtoStructValue::Create(value_factory, std::move(*proto)); + } + case Kind::kWrapper: { + switch ((*type)->As().wrapped()->kind()) { + case Kind::kBool: { + // google.protobuf.BoolValue + CEL_ASSIGN_OR_RETURN(auto proto, + UnpackTo(payload)); + return Create(value_factory, proto); + } + case Kind::kInt: { + // google.protobuf.{Int32Value,Int64Value} + if (type_name == "google.protobuf.Int32Value") { + CEL_ASSIGN_OR_RETURN( + auto proto, UnpackTo(payload)); + return Create(value_factory, std::move(proto)); + } + if (type_name == "google.protobuf.Int64Value") { + CEL_ASSIGN_OR_RETURN( + auto proto, UnpackTo(payload)); + return Create(value_factory, std::move(proto)); + } + } break; + case Kind::kUint: { + // google.protobuf.{UInt32Value,UInt64Value} + if (type_name == "google.protobuf.UInt32Value") { + CEL_ASSIGN_OR_RETURN( + auto proto, UnpackTo(payload)); + return Create(value_factory, std::move(proto)); + } + if (type_name == "google.protobuf.UInt64Value") { + CEL_ASSIGN_OR_RETURN( + auto proto, UnpackTo(payload)); + return Create(value_factory, std::move(proto)); + } + } break; + case Kind::kDouble: { + // google.protobuf.{FloatValue,DoubleValue} + if (type_name == "google.protobuf.FloatValue") { + CEL_ASSIGN_OR_RETURN( + auto proto, UnpackTo(payload)); + return Create(value_factory, std::move(proto)); + } + if (type_name == "google.protobuf.DoubleValue") { + CEL_ASSIGN_OR_RETURN( + auto proto, UnpackTo(payload)); + return Create(value_factory, std::move(proto)); + } + } break; + case Kind::kBytes: { + // google.protobuf.BytesValue + CEL_ASSIGN_OR_RETURN(auto proto, + UnpackTo(payload)); + return Create(value_factory, std::move(proto)); + } + case Kind::kString: { + // google.protobuf.StringValue + CEL_ASSIGN_OR_RETURN( + auto proto, UnpackTo(payload)); + return Create(value_factory, std::move(proto)); + } + default: + ABSL_UNREACHABLE(); + } + } break; + case Kind::kList: { + // google.protobuf.ListValue + ABSL_DCHECK(type_name == "google.protobuf.ListValue") << type_name; + CEL_ASSIGN_OR_RETURN(auto proto, + UnpackTo(payload)); + return Create(value_factory, std::move(proto)); + } + case Kind::kMap: { + // google.protobuf.Struct + ABSL_DCHECK(type_name == "google.protobuf.Struct") << type_name; + CEL_ASSIGN_OR_RETURN(auto proto, + UnpackTo(payload)); + return Create(value_factory, std::move(proto)); + } + case Kind::kDyn: { + // google.protobuf.Value + ABSL_DCHECK(type_name == "google.protobuf.Value") << type_name; + CEL_ASSIGN_OR_RETURN(auto proto, + UnpackTo(payload)); + return Create(value_factory, std::move(proto)); + } + case Kind::kDuration: { + // google.protobuf.Duration + ABSL_DCHECK(type_name == "google.protobuf.Duration") << type_name; + CEL_ASSIGN_OR_RETURN(auto proto, + UnpackTo(payload)); + return Create(value_factory, proto); + } + case Kind::kTimestamp: { + // google.protobuf.Timestamp + ABSL_DCHECK(type_name == "google.protobuf.Timestamp") << type_name; + CEL_ASSIGN_OR_RETURN(auto proto, + UnpackTo(payload)); + return Create(value_factory, proto); + } + default: + break; + } + return absl::UnimplementedError( + absl::StrCat("google.protobuf.Any unpacking to ", (*type)->DebugString(), + " is not implemented")); +} + namespace protobuf_internal { absl::StatusOr> CreateBorrowedListValue( diff --git a/extensions/protobuf/value.h b/extensions/protobuf/value.h index be70fec67..fd22129a1 100644 --- a/extensions/protobuf/value.h +++ b/extensions/protobuf/value.h @@ -19,12 +19,14 @@ #include #include +#include "google/protobuf/any.pb.h" #include "google/protobuf/duration.pb.h" #include "google/protobuf/struct.pb.h" #include "google/protobuf/timestamp.pb.h" #include "google/protobuf/wrappers.pb.h" #include "absl/base/attributes.h" #include "absl/status/statusor.h" +#include "absl/strings/cord.h" #include "absl/time/time.h" #include "base/handle.h" #include "base/owner.h" @@ -122,6 +124,12 @@ class ProtoValue final { template using NotJsonMessage = std::negation>; + template + using AnyMessage = std::is_same>; + + template + using NotAnyMessage = std::negation>; + public: // Create a new EnumValue from a generated protocol buffer enum. template @@ -148,7 +156,7 @@ class ProtoValue final { static std::enable_if_t< std::conjunction_v, NotDurationMessage, NotTimestampMessage, NotWrapperMessage, - NotJsonMessage>, + NotJsonMessage, NotAnyMessage>, absl::StatusOr>> Create(ValueFactory& value_factory, T&& value) { return ProtoStructValue::Create(value_factory, std::forward(value)); @@ -158,7 +166,7 @@ class ProtoValue final { static std::enable_if_t< std::conjunction_v, NotDurationMessage, NotTimestampMessage, NotWrapperMessage, - NotJsonMessage>, + NotJsonMessage, NotAnyMessage>, absl::StatusOr>> CreateBorrowed(ValueFactory& value_factory, const T& value ABSL_ATTRIBUTE_LIFETIME_BOUND) { @@ -234,6 +242,15 @@ class ProtoValue final { return value_factory.CreateUintValue(value.value()); } + // Create a new Value from google.protobuf.Any. + static absl::StatusOr> Create( + ValueFactory& value_factory, const google::protobuf::Any& value) { + return Create(value_factory, value.type_url(), absl::Cord(value.value())); + } + static absl::StatusOr> Create(ValueFactory& value_factory, + absl::string_view type_url, + const absl::Cord& payload); + static absl::StatusOr> Create( ValueFactory& value_factory, google::protobuf::ListValue value); diff --git a/extensions/protobuf/value_test.cc b/extensions/protobuf/value_test.cc index a911de316..6909b0c1c 100644 --- a/extensions/protobuf/value_test.cc +++ b/extensions/protobuf/value_test.cc @@ -16,15 +16,19 @@ #include +#include "google/protobuf/api.pb.h" #include "google/protobuf/duration.pb.h" #include "google/protobuf/struct.pb.h" +#include "google/protobuf/wrappers.pb.h" #include "base/internal/memory_manager_testing.h" #include "base/testing/value_matchers.h" #include "base/type_factory.h" #include "base/type_manager.h" #include "base/value_factory.h" #include "extensions/protobuf/enum_value.h" +#include "extensions/protobuf/internal/descriptors.h" #include "extensions/protobuf/internal/testing.h" +#include "extensions/protobuf/struct_value.h" #include "extensions/protobuf/type_provider.h" #include "internal/testing.h" #include "proto/test/v1/proto3/test_all_types.pb.h" @@ -597,6 +601,212 @@ TEST_P(ProtoValueTest, StaticRValueStruct) { IsOkAndHolds(Optional(ValueOf(value_factory, true)))); } +enum class ProtoValueAnyTestRunner { + kGenerated, + kCustom, +}; + +template +void AbslStringify(S& sink, ProtoValueAnyTestRunner value) { + switch (value) { + case ProtoValueAnyTestRunner::kGenerated: + sink.Append("Generated"); + break; + case ProtoValueAnyTestRunner::kCustom: + sink.Append("Custom"); + break; + } +} + +class ProtoValueAnyTest : public ProtoTest { + protected: + template + void Run( + const T& message, + absl::FunctionRef&)> tester) { + google::protobuf::Any any; + ASSERT_TRUE(any.PackFrom(message)); + switch (std::get<1>(GetParam())) { + case ProtoValueAnyTestRunner::kGenerated: { + TypeFactory type_factory(memory_manager()); + ProtoTypeProvider type_provider; + TypeManager type_manager(type_factory, type_provider); + ValueFactory value_factory(type_manager); + ASSERT_OK_AND_ASSIGN(auto value, + ProtoValue::Create(value_factory, message)); + tester(value_factory, value); + return; + } + case ProtoValueAnyTestRunner::kCustom: { + } + protobuf_internal::WithCustomDescriptorPool( + memory_manager(), any, *T::descriptor(), + [&](TypeProvider& type_provider, + const google::protobuf::Message& custom_message) { + TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, type_provider); + ValueFactory value_factory(type_manager); + ASSERT_OK_AND_ASSIGN( + auto value, + ProtoValue::Create(value_factory, custom_message)); + tester(value_factory, value); + }); + return; + } + } + + template + void Run(const T& message, + absl::FunctionRef&)> tester) { + Run(message, [&](ValueFactory& value_factory, const Handle& value) { + tester(value); + }); + } +}; + +TEST_P(ProtoValueAnyTest, AnyBoolWrapper) { + google::protobuf::BoolValue payload; + payload.set_value(true); + Run(payload, [](const Handle& value) { + EXPECT_EQ(value.As()->value(), true); + }); +} + +TEST_P(ProtoValueAnyTest, AnyInt32Wrapper) { + google::protobuf::Int32Value payload; + payload.set_value(1); + Run(payload, [](const Handle& value) { + EXPECT_EQ(value.As()->value(), 1); + }); +} + +TEST_P(ProtoValueAnyTest, AnyInt64Wrapper) { + google::protobuf::Int64Value payload; + payload.set_value(1); + Run(payload, [](const Handle& value) { + EXPECT_EQ(value.As()->value(), 1); + }); +} + +TEST_P(ProtoValueAnyTest, AnyUInt32Wrapper) { + google::protobuf::UInt32Value payload; + payload.set_value(1); + Run(payload, [](const Handle& value) { + EXPECT_EQ(value.As()->value(), 1); + }); +} + +TEST_P(ProtoValueAnyTest, AnyUInt64Wrapper) { + google::protobuf::UInt64Value payload; + payload.set_value(1); + Run(payload, [](const Handle& value) { + EXPECT_EQ(value.As()->value(), 1); + }); +} + +TEST_P(ProtoValueAnyTest, AnyFloatWrapper) { + google::protobuf::FloatValue payload; + payload.set_value(1); + Run(payload, [](const Handle& value) { + EXPECT_EQ(value.As()->value(), 1); + }); +} + +TEST_P(ProtoValueAnyTest, AnyDoubleWrapper) { + google::protobuf::DoubleValue payload; + payload.set_value(1); + Run(payload, [](const Handle& value) { + EXPECT_EQ(value.As()->value(), 1); + }); +} + +TEST_P(ProtoValueAnyTest, AnyBytesWrapper) { + google::protobuf::BytesValue payload; + payload.set_value("foo"); + Run(payload, [](const Handle& value) { + EXPECT_EQ(value.As()->ToString(), "foo"); + }); +} + +TEST_P(ProtoValueAnyTest, AnyStringWrapper) { + google::protobuf::StringValue payload; + payload.set_value("foo"); + Run(payload, [](const Handle& value) { + EXPECT_EQ(value.As()->ToString(), "foo"); + }); +} + +TEST_P(ProtoValueAnyTest, AnyDuration) { + google::protobuf::Duration payload; + payload.set_seconds(1); + Run(payload, [](const Handle& value) { + EXPECT_EQ(value.As()->value(), absl::Seconds(1)); + }); +} + +TEST_P(ProtoValueAnyTest, AnyTimestamp) { + google::protobuf::Timestamp payload; + payload.set_seconds(1); + Run(payload, [](const Handle& value) { + EXPECT_EQ(value.As()->value(), + absl::UnixEpoch() + absl::Seconds(1)); + }); +} + +TEST_P(ProtoValueAnyTest, AnyValue) { + google::protobuf::Value payload; + payload.set_bool_value(true); + Run(payload, [](const Handle& value) { + EXPECT_TRUE(value.As()->value()); + }); +} + +TEST_P(ProtoValueAnyTest, AnyListValue) { + google::protobuf::ListValue payload; + payload.add_values()->set_bool_value(true); + Run(payload, [](ValueFactory& value_factory, const Handle& value) { + ASSERT_TRUE(value->Is()); + EXPECT_EQ(value.As()->size(), 1); + ASSERT_OK_AND_ASSIGN( + auto element, + value->As().Get(ListValue::GetContext(value_factory), 0)); + ASSERT_TRUE(element->Is()); + EXPECT_TRUE(element.As()->value()); + }); +} + +TEST_P(ProtoValueAnyTest, AnyMessage) { + google::protobuf::Struct payload; + payload.mutable_fields()->insert( + {"foo", google::protobuf::Value::default_instance()}); + Run(payload, [](ValueFactory& value_factory, const Handle& value) { + ASSERT_TRUE(value->Is()); + EXPECT_EQ(value.As()->size(), 1); + ASSERT_OK_AND_ASSIGN(auto key, value_factory.CreateStringValue("foo")); + ASSERT_OK_AND_ASSIGN( + auto field, + value->As().Get(MapValue::GetContext(value_factory), key)); + ASSERT_TRUE(field.has_value()); + ASSERT_TRUE((*field)->Is()); + }); +} + +TEST_P(ProtoValueAnyTest, AnyStruct) { + google::protobuf::Api payload; + payload.set_name("foo"); + Run(payload, [&payload](const Handle& value) { + ASSERT_TRUE(value->Is()); + EXPECT_EQ(value->As().value()->SerializeAsString(), + payload.SerializeAsString()); + }); +} + +INSTANTIATE_TEST_SUITE_P( + ProtoValueAnyTest, ProtoValueAnyTest, + testing::Combine(cel::base_internal::MemoryManagerTestModeAll(), + testing::Values(ProtoValueAnyTestRunner::kGenerated, + ProtoValueAnyTestRunner::kCustom))); + TEST_P(ProtoValueTest, StaticWrapperTypes) { TypeFactory type_factory(memory_manager()); ProtoTypeProvider type_provider; From dc69778a0bc395a3bcf7dbe5596607968c76d4b4 Mon Sep 17 00:00:00 2001 From: jdtatum Date: Mon, 15 May 2023 19:38:55 +0000 Subject: [PATCH 271/303] Remove unused planning arena. PiperOrigin-RevId: 532193349 --- eval/compiler/flat_expr_builder.cc | 24 +++++++----------------- eval/eval/const_value_step_test.cc | 2 +- eval/eval/evaluator_core.h | 8 ++------ eval/eval/logic_step_test.cc | 3 +-- eval/eval/shadowable_value_step_test.cc | 2 +- eval/eval/ternary_step_test.cc | 3 +-- 6 files changed, 13 insertions(+), 29 deletions(-) diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index 4f87341ad..2ef8c42b3 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -279,7 +279,7 @@ class FlatExprVisitor : public cel::ast::internal::AstVisitor { reference_map, google::api::expr::runtime::ExecutionPath* path, google::api::expr::runtime::BuilderWarnings* warnings, - google::protobuf::Arena* arena, PlannerContext::ProgramTree& program_tree, + PlannerContext::ProgramTree& program_tree, PlannerContext& extension_context) : resolver_(resolver), execution_path_(path), @@ -295,7 +295,6 @@ class FlatExprVisitor : public cel::ast::internal::AstVisitor { builder_warnings_(warnings), regex_program_builder_(options_.regex_max_program_size), reference_map_(reference_map), - arena_(arena), program_tree_(program_tree), extension_context_(extension_context) {} @@ -870,7 +869,6 @@ class FlatExprVisitor : public cel::ast::internal::AstVisitor { const absl::flat_hash_map* const reference_map_; - google::protobuf::Arena* const arena_; PlannerContext::ProgramTree& program_tree_; PlannerContext extension_context_; }; @@ -1327,18 +1325,16 @@ FlatExprBuilder::CreateExpressionImpl( effective_expr = &const_fold_buffer; } - auto arena = std::make_unique(); - std::vector> optimizers; for (const ProgramOptimizerFactory& optimizer_factory : program_optimizers_) { CEL_ASSIGN_OR_RETURN(optimizers.emplace_back(), optimizer_factory(extension_context, ast_impl)); } - FlatExprVisitor visitor( - resolver, options_, constant_idents, - enable_comprehension_vulnerability_check_, enable_regex_precompilation_, - optimizers, &ast_impl.reference_map(), &execution_path, &warnings_builder, - arena.get(), program_tree, extension_context); + FlatExprVisitor visitor(resolver, options_, constant_idents, + enable_comprehension_vulnerability_check_, + enable_regex_precompilation_, optimizers, + &ast_impl.reference_map(), &execution_path, + &warnings_builder, program_tree, extension_context); AstTraverse(effective_expr, &ast_impl.source_info(), &visitor); @@ -1346,15 +1342,9 @@ FlatExprBuilder::CreateExpressionImpl( return visitor.progress_status(); } - if (arena->SpaceUsed() == 0) { - // No space in the arena was used, delete it. - arena.reset(); - } - std::unique_ptr expression_impl = std::make_unique(std::move(execution_path), - GetTypeRegistry(), options_, - std::move(arena)); + GetTypeRegistry(), options_); if (warnings != nullptr) { *warnings = std::move(warnings_builder).warnings(); diff --git a/eval/eval/const_value_step_test.cc b/eval/eval/const_value_step_test.cc index bd2a247ec..c5f5f6aff 100644 --- a/eval/eval/const_value_step_test.cc +++ b/eval/eval/const_value_step_test.cc @@ -38,7 +38,7 @@ absl::StatusOr RunConstantExpression(const Expr* expr, CelExpressionFlatImpl impl(std::move(path), &google::api::expr::runtime::TestTypeRegistry(), - cel::RuntimeOptions{}, {}); + cel::RuntimeOptions{}); google::api::expr::runtime::Activation activation; diff --git a/eval/eval/evaluator_core.h b/eval/eval/evaluator_core.h index b10bb4efb..9af2a3726 100644 --- a/eval/eval/evaluator_core.h +++ b/eval/eval/evaluator_core.h @@ -286,10 +286,8 @@ class CelExpressionFlatImpl : public CelExpression { // bound). CelExpressionFlatImpl(ExecutionPath path, const CelTypeRegistry* type_registry, - const cel::RuntimeOptions& options, - std::unique_ptr arena = nullptr) - : arena_(std::move(arena)), - path_(std::move(path)), + const cel::RuntimeOptions& options) + : path_(std::move(path)), type_registry_(*type_registry), options_(options) {} @@ -321,8 +319,6 @@ class CelExpressionFlatImpl : public CelExpression { CelEvaluationListener callback) const override; private: - // Arena used while building the expression, must live as long. - const std::unique_ptr arena_; const ExecutionPath path_; const CelTypeRegistry& type_registry_; cel::RuntimeOptions options_; diff --git a/eval/eval/logic_step_test.cc b/eval/eval/logic_step_test.cc index 471351206..a76264fd1 100644 --- a/eval/eval/logic_step_test.cc +++ b/eval/eval/logic_step_test.cc @@ -48,8 +48,7 @@ class LogicStepTest : public testing::TestWithParam { options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; } - CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), options, - {}); + CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), options); Activation activation; activation.InsertValue("name0", arg0); diff --git a/eval/eval/shadowable_value_step_test.cc b/eval/eval/shadowable_value_step_test.cc index 9bb32d1c2..cd65883ab 100644 --- a/eval/eval/shadowable_value_step_test.cc +++ b/eval/eval/shadowable_value_step_test.cc @@ -34,7 +34,7 @@ absl::StatusOr RunShadowableExpression(std::string identifier, path.push_back(std::move(step)); CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), - cel::RuntimeOptions{}, {}); + cel::RuntimeOptions{}); return impl.Evaluate(activation, arena); } diff --git a/eval/eval/ternary_step_test.cc b/eval/eval/ternary_step_test.cc index 524d7c84a..2d983a132 100644 --- a/eval/eval/ternary_step_test.cc +++ b/eval/eval/ternary_step_test.cc @@ -59,8 +59,7 @@ class LogicStepTest : public testing::TestWithParam { options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; } - CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), options, - {}); + CelExpressionFlatImpl impl(std::move(path), &TestTypeRegistry(), options); Activation activation; std::string value("test"); From 7f758e3310aaffe427edd6139c27e407da4d1a76 Mon Sep 17 00:00:00 2001 From: jcking Date: Mon, 15 May 2023 21:26:45 +0000 Subject: [PATCH 272/303] Remove usage of `Reflection::GetStringView` PiperOrigin-RevId: 532225932 --- base/value_factory.cc | 12 +++ base/value_factory.h | 2 + extensions/protobuf/BUILD | 2 +- extensions/protobuf/internal/BUILD | 15 +++- extensions/protobuf/internal/reflection.cc | 90 +++++++++++++++++++ extensions/protobuf/internal/reflection.h | 50 +++++++++++ extensions/protobuf/internal/wrappers.cc | 43 ++------- extensions/protobuf/internal/wrappers.h | 6 +- extensions/protobuf/internal/wrappers_test.cc | 5 +- extensions/protobuf/struct_value.cc | 68 +++----------- extensions/protobuf/value.cc | 57 ++++-------- 11 files changed, 211 insertions(+), 139 deletions(-) create mode 100644 extensions/protobuf/internal/reflection.cc create mode 100644 extensions/protobuf/internal/reflection.h diff --git a/base/value_factory.cc b/base/value_factory.cc index c05c7dd51..a6d1075f6 100644 --- a/base/value_factory.cc +++ b/base/value_factory.cc @@ -23,6 +23,7 @@ #include "absl/status/statusor.h" #include "base/handle.h" #include "base/value.h" +#include "base/values/string_value.h" #include "internal/status_macros.h" #include "internal/time.h" #include "internal/utf8.h" @@ -158,6 +159,17 @@ Handle ValueFactory::CreateUncheckedStringValue( std::move(value)); } +Handle ValueFactory::CreateUncheckedStringValue(absl::Cord value) { + // Avoid persisting empty strings which may have underlying storage after + // mutating. + if (value.empty()) { + return GetEmptyStringValue(); + } + + return HandleFactory::Make( + std::move(value)); +} + absl::StatusOr> ValueFactory::CreateStringValue( absl::Cord value) { if (value.empty()) { diff --git a/base/value_factory.h b/base/value_factory.h index 2e02a0d21..4538aa717 100644 --- a/base/value_factory.h +++ b/base/value_factory.h @@ -196,6 +196,8 @@ class ValueFactory final { // already been validated as utf-8. Handle CreateUncheckedStringValue(std::string value) ABSL_ATTRIBUTE_LIFETIME_BOUND; + Handle CreateUncheckedStringValue(absl::Cord value) + ABSL_ATTRIBUTE_LIFETIME_BOUND; absl::StatusOr> CreateStringValue(absl::Cord value) ABSL_ATTRIBUTE_LIFETIME_BOUND; diff --git a/extensions/protobuf/BUILD b/extensions/protobuf/BUILD index 2cda3c67d..785bef23c 100644 --- a/extensions/protobuf/BUILD +++ b/extensions/protobuf/BUILD @@ -158,6 +158,7 @@ cc_library( "//eval/public:message_wrapper", "//eval/public/structs:proto_message_type_adapter", "//extensions/protobuf/internal:map_reflection", + "//extensions/protobuf/internal:reflection", "//extensions/protobuf/internal:time", "//extensions/protobuf/internal:wrappers", "//internal:casts", @@ -174,7 +175,6 @@ cc_library( "@com_google_absl//absl/strings:cord", "@com_google_absl//absl/time", "@com_google_absl//absl/types:optional", - "@com_google_absl//absl/types:variant", "@com_google_protobuf//:protobuf", ], ) diff --git a/extensions/protobuf/internal/BUILD b/extensions/protobuf/internal/BUILD index 76a83e684..e7f4e2ff9 100644 --- a/extensions/protobuf/internal/BUILD +++ b/extensions/protobuf/internal/BUILD @@ -44,6 +44,20 @@ cc_library( deps = ["@com_google_protobuf//:protobuf"], ) +cc_library( + name = "reflection", + srcs = ["reflection.cc"], + hdrs = ["reflection.h"], + deps = [ + "//base:handle", + "//base:owner", + "//base:value", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", + ], +) + cc_library( name = "testing", testonly = True, @@ -94,7 +108,6 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:cord", - "@com_google_absl//absl/types:variant", "@com_google_protobuf//:protobuf", ], ) diff --git a/extensions/protobuf/internal/reflection.cc b/extensions/protobuf/internal/reflection.cc new file mode 100644 index 000000000..15adffb76 --- /dev/null +++ b/extensions/protobuf/internal/reflection.cc @@ -0,0 +1,90 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/protobuf/internal/reflection.h" + +#include +#include + +#include "google/protobuf/descriptor.pb.h" + +namespace cel::extensions::protobuf_internal { + +namespace { + +bool IsCordField(const google::protobuf::FieldDescriptor& field) { + return !field.is_extension() && + field.options().ctype() == google::protobuf::FieldOptions::CORD; +} + +} // namespace + +absl::StatusOr> GetStringField( + ValueFactory& value_factory, const google::protobuf::Message& message, + const google::protobuf::Reflection* reflection, + const google::protobuf::FieldDescriptor* field) { + if (IsCordField(*field)) { + return value_factory.CreateUncheckedStringValue( + reflection->GetCord(message, field)); + } + return value_factory.CreateUncheckedStringValue( + reflection->GetString(message, field)); +} + +absl::StatusOr> GetBorrowedStringField( + ValueFactory& value_factory, Owner owner, + const google::protobuf::Message& message, const google::protobuf::Reflection* reflection, + const google::protobuf::FieldDescriptor* field) { + if (IsCordField(*field)) { + return value_factory.CreateUncheckedStringValue( + reflection->GetCord(message, field)); + } + std::string scratch; + const std::string& reference = + reflection->GetStringReference(message, field, &scratch); + if (&reference == &scratch) { + return value_factory.CreateUncheckedStringValue(std::move(scratch)); + } + return value_factory.CreateBorrowedStringValue(std::move(owner), + absl::string_view(reference)); +} + +absl::StatusOr> GetBytesField( + ValueFactory& value_factory, const google::protobuf::Message& message, + const google::protobuf::Reflection* reflection, + const google::protobuf::FieldDescriptor* field) { + if (IsCordField(*field)) { + return value_factory.CreateBytesValue(reflection->GetCord(message, field)); + } + return value_factory.CreateBytesValue(reflection->GetString(message, field)); +} + +absl::StatusOr> GetBorrowedBytesField( + ValueFactory& value_factory, Owner owner, + const google::protobuf::Message& message, const google::protobuf::Reflection* reflection, + const google::protobuf::FieldDescriptor* field) { + if (IsCordField(*field)) { + return value_factory.CreateBytesValue(reflection->GetCord(message, field)); + } + std::string scratch; + const std::string& reference = + reflection->GetStringReference(message, field, &scratch); + if (&reference == &scratch) { + return value_factory.CreateBytesValue(std::move(scratch)); + } + return value_factory.CreateBorrowedBytesValue(std::move(owner), + absl::string_view(reference)); +} + +} // namespace cel::extensions::protobuf_internal diff --git a/extensions/protobuf/internal/reflection.h b/extensions/protobuf/internal/reflection.h new file mode 100644 index 000000000..c7cbec629 --- /dev/null +++ b/extensions/protobuf/internal/reflection.h @@ -0,0 +1,50 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_INTERNAL_REFLECTION_H_ +#define THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_INTERNAL_REFLECTION_H_ + +#include "absl/base/attributes.h" +#include "absl/status/statusor.h" +#include "base/handle.h" +#include "base/owner.h" +#include "base/value_factory.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel::extensions::protobuf_internal { + +absl::StatusOr> GetStringField( + ValueFactory& value_factory, const google::protobuf::Message& message, + const google::protobuf::Reflection* reflection, const google::protobuf::FieldDescriptor* field) + ABSL_ATTRIBUTE_NONNULL(); + +absl::StatusOr> GetBorrowedStringField( + ValueFactory& value_factory, Owner owner, + const google::protobuf::Message& message, const google::protobuf::Reflection* reflection, + const google::protobuf::FieldDescriptor* field) ABSL_ATTRIBUTE_NONNULL(); + +absl::StatusOr> GetBytesField( + ValueFactory& value_factory, const google::protobuf::Message& message, + const google::protobuf::Reflection* reflection, const google::protobuf::FieldDescriptor* field) + ABSL_ATTRIBUTE_NONNULL(); + +absl::StatusOr> GetBorrowedBytesField( + ValueFactory& value_factory, Owner owner, + const google::protobuf::Message& message, const google::protobuf::Reflection* reflection, + const google::protobuf::FieldDescriptor* field) ABSL_ATTRIBUTE_NONNULL(); + +} // namespace cel::extensions::protobuf_internal + +#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_INTERNAL_REFLECTION_H_ diff --git a/extensions/protobuf/internal/wrappers.cc b/extensions/protobuf/internal/wrappers.cc index 0315f0b7f..de075fb7c 100644 --- a/extensions/protobuf/internal/wrappers.cc +++ b/extensions/protobuf/internal/wrappers.cc @@ -15,7 +15,6 @@ #include "extensions/protobuf/internal/wrappers.h" #include "google/protobuf/wrappers.pb.h" -#include "google/protobuf/descriptor.pb.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "internal/casts.h" @@ -36,7 +35,7 @@ absl::StatusOr

UnwrapValueProto(const google::protobuf::Message& message, } if (desc == T::descriptor()) { // Fast path. - return cel::internal::down_cast(message).value(); + return P(cel::internal::down_cast(message).value()); } const auto* reflect = message.GetReflection(); if (ABSL_PREDICT_FALSE(reflect == nullptr)) { @@ -123,41 +122,11 @@ absl::StatusOr UnwrapInt64ValueProto(const google::protobuf::Message& m &google::protobuf::Reflection::GetInt64); } -absl::StatusOr> -UnwrapStringValueProto(const google::protobuf::Message& message) { - const auto* desc = message.GetDescriptor(); - if (ABSL_PREDICT_FALSE(desc == nullptr)) { - return absl::InternalError( - absl::StrCat(message.GetTypeName(), " missing descriptor")); - } - if (desc == google::protobuf::StringValue::descriptor()) { - // Fast path. - return absl::string_view( - cel::internal::down_cast(message) - .value()); - } - const auto* reflect = message.GetReflection(); - if (ABSL_PREDICT_FALSE(reflect == nullptr)) { - return absl::InternalError( - absl::StrCat(message.GetTypeName(), " missing reflection")); - } - const auto* value_field = - desc->FindFieldByNumber(google::protobuf::StringValue::kValueFieldNumber); - if (ABSL_PREDICT_FALSE(value_field == nullptr)) { - return absl::InternalError( - absl::StrCat(message.GetTypeName(), " missing value field descriptor")); - } - if (ABSL_PREDICT_FALSE(value_field->cpp_type() != - google::protobuf::FieldDescriptor::CPPTYPE_STRING)) { - return absl::InternalError(absl::StrCat( - message.GetTypeName(), - " has unexpected value field type: ", value_field->cpp_type_name())); - } - if (value_field->options().ctype() == google::protobuf::FieldOptions::CORD && - !value_field->is_extension()) { - return reflect->GetCord(message, value_field); - } - return reflect->GetStringView(message, value_field); +absl::StatusOr UnwrapStringValueProto( + const google::protobuf::Message& message) { + return UnwrapValueProto( + message, google::protobuf::FieldDescriptor::CPPTYPE_STRING, + &google::protobuf::Reflection::GetCord); } absl::StatusOr UnwrapUIntValueProto(const google::protobuf::Message& message) { diff --git a/extensions/protobuf/internal/wrappers.h b/extensions/protobuf/internal/wrappers.h index c0c5db15a..00ea556f1 100644 --- a/extensions/protobuf/internal/wrappers.h +++ b/extensions/protobuf/internal/wrappers.h @@ -17,8 +17,6 @@ #include "absl/status/statusor.h" #include "absl/strings/cord.h" -#include "absl/strings/string_view.h" -#include "absl/types/variant.h" #include "google/protobuf/message.h" namespace cel::extensions::protobuf_internal { @@ -38,8 +36,8 @@ absl::StatusOr UnwrapInt32ValueProto(const google::protobuf::Message& m absl::StatusOr UnwrapInt64ValueProto(const google::protobuf::Message& message); -absl::StatusOr> -UnwrapStringValueProto(const google::protobuf::Message& message); +absl::StatusOr UnwrapStringValueProto( + const google::protobuf::Message& message); absl::StatusOr UnwrapUIntValueProto(const google::protobuf::Message& message); diff --git a/extensions/protobuf/internal/wrappers_test.cc b/extensions/protobuf/internal/wrappers_test.cc index 0a4a88802..7cdc3c80a 100644 --- a/extensions/protobuf/internal/wrappers_test.cc +++ b/extensions/protobuf/internal/wrappers_test.cc @@ -25,7 +25,6 @@ namespace cel::extensions::protobuf_internal { namespace { using testing::Eq; -using testing::VariantWith; using cel::internal::IsOkAndHolds; TEST(BoolWrapper, Generated) { @@ -124,7 +123,7 @@ TEST(IntWrapper, Custom) { TEST(StringWrapper, Generated) { EXPECT_THAT(UnwrapStringValueProto(google::protobuf::StringValue()), - IsOkAndHolds(VariantWith(""))); + IsOkAndHolds(absl::Cord())); } TEST(StringWrapper, Custom) { @@ -140,7 +139,7 @@ TEST(StringWrapper, Custom) { factory.SetDelegateToGeneratedFactory(false); EXPECT_THAT(UnwrapStringValueProto(*factory.GetPrototype( pool.FindMessageTypeByName("google.protobuf.StringValue"))), - IsOkAndHolds(VariantWith(""))); + IsOkAndHolds(absl::Cord())); } TEST(UintWrapper, Generated) { diff --git a/extensions/protobuf/struct_value.cc b/extensions/protobuf/struct_value.cc index a2f6623e5..ee70b8ab2 100644 --- a/extensions/protobuf/struct_value.cc +++ b/extensions/protobuf/struct_value.cc @@ -23,7 +23,6 @@ #include #include -#include "google/protobuf/descriptor.pb.h" #include "absl/base/attributes.h" #include "absl/base/macros.h" #include "absl/base/optimization.h" @@ -35,7 +34,6 @@ #include "absl/strings/string_view.h" #include "absl/time/time.h" #include "absl/types/optional.h" -#include "absl/types/variant.h" #include "base/handle.h" #include "base/memory.h" #include "base/types/struct_type.h" @@ -55,6 +53,7 @@ #include "eval/public/structs/proto_message_type_adapter.h" #include "extensions/protobuf/enum_type.h" #include "extensions/protobuf/internal/map_reflection.h" +#include "extensions/protobuf/internal/reflection.h" #include "extensions/protobuf/internal/time.h" #include "extensions/protobuf/internal/wrappers.h" #include "extensions/protobuf/memory_manager.h" @@ -96,16 +95,6 @@ namespace protobuf_internal { namespace { -struct DebugStringFromStringWrapperVisitor final { - std::string operator()(absl::string_view value) const { - return StringValue::DebugString(value); - } - - std::string operator()(const absl::Cord& value) const { - return StringValue::DebugString(value); - } -}; - class HeapDynamicParsedProtoStructValue : public DynamicParsedProtoStructValue { public: HeapDynamicParsedProtoStructValue(Handle type, @@ -217,8 +206,7 @@ std::string StringValueDebugStringFromProto(const google::protobuf::Message& mes if (ABSL_PREDICT_FALSE(!value_or_status.ok())) { return std::string("**google.protobuf.StringValue**"); } - return absl::visit(protobuf_internal::DebugStringFromStringWrapperVisitor{}, - *value_or_status); + return StringValue::DebugString(*value_or_status); } std::string UintValueDebugStringFromProto(const google::protobuf::Message& message) { @@ -1194,14 +1182,8 @@ class ParsedProtoListValue const auto& field = fields_.Get(static_cast(index), scratch.get()); CEL_ASSIGN_OR_RETURN(auto wrapped, protobuf_internal::UnwrapStringValueProto(field)); - if (absl::holds_alternative(wrapped)) { - return context.value_factory().CreateBorrowedStringValue( - owner_from_this(), absl::get(wrapped)); - } else { - ABSL_ASSERT(absl::holds_alternative(wrapped)); - return context.value_factory().CreateStringValue( - absl::get(std::move(wrapped))); - } + return context.value_factory().CreateUncheckedStringValue( + std::move(wrapped)); } private: @@ -1658,14 +1640,8 @@ class ParsedProtoMapValue : public CEL_MAP_VALUE_CLASS { CEL_ASSIGN_OR_RETURN(auto wrapped, protobuf_internal::UnwrapStringValueProto( proto_value.GetMessageValue())); - if (absl::holds_alternative(wrapped)) { - return context.value_factory().CreateBorrowedStringValue( - owner_from_this(), absl::get(wrapped)); - } else { - ABSL_ASSERT(absl::holds_alternative(wrapped)); - return context.value_factory().CreateStringValue( - absl::get(std::move(wrapped))); - } + return context.value_factory().CreateUncheckedStringValue( + std::move(wrapped)); } case Kind::kUint: { // google.protobuf.{UInt32Value,UInt64Value}, mapped to CEL @@ -2533,14 +2509,9 @@ absl::StatusOr> ParsedProtoStructValue::GetSingularField( return context.value_factory().CreateBoolValue( reflect.GetBool(value(), &field_desc)); case google::protobuf::FieldDescriptor::TYPE_STRING: - if (field_desc.options().ctype() == google::protobuf::FieldOptions::CORD && - !field_desc.is_extension()) { - return context.value_factory().CreateStringValue( - reflect.GetCord(value(), &field_desc)); - } else { - return context.value_factory().CreateBorrowedStringValue( - owner_from_this(), reflect.GetStringView(value(), &field_desc)); - } + return protobuf_internal::GetBorrowedStringField( + context.value_factory(), owner_from_this(), value(), &reflect, + &field_desc); case google::protobuf::FieldDescriptor::TYPE_GROUP: ABSL_FALLTHROUGH_INTENDED; case google::protobuf::FieldDescriptor::TYPE_MESSAGE: @@ -2619,14 +2590,8 @@ absl::StatusOr> ParsedProtoStructValue::GetSingularField( auto wrapped, protobuf_internal::UnwrapStringValueProto(reflect.GetMessage( value(), &field_desc, type()->factory_))); - if (absl::holds_alternative(wrapped)) { - return context.value_factory().CreateBorrowedStringValue( - owner_from_this(), absl::get(wrapped)); - } else { - ABSL_ASSERT(absl::holds_alternative(wrapped)); - return context.value_factory().CreateStringValue( - absl::get(std::move(wrapped))); - } + return context.value_factory().CreateUncheckedStringValue( + std::move(wrapped)); } case Kind::kUint: { CEL_ASSIGN_OR_RETURN( @@ -2649,14 +2614,9 @@ absl::StatusOr> ParsedProtoStructValue::GetSingularField( ABSL_UNREACHABLE(); } case google::protobuf::FieldDescriptor::TYPE_BYTES: - if (field_desc.options().ctype() == google::protobuf::FieldOptions::CORD && - !field_desc.is_extension()) { - return context.value_factory().CreateBytesValue( - reflect.GetCord(value(), &field_desc)); - } else { - return context.value_factory().CreateBorrowedBytesValue( - owner_from_this(), reflect.GetStringView(value(), &field_desc)); - } + return protobuf_internal::GetBorrowedBytesField( + context.value_factory(), owner_from_this(), value(), &reflect, + &field_desc); case google::protobuf::FieldDescriptor::TYPE_ENUM: switch (field.type->kind()) { case Kind::kNullType: diff --git a/extensions/protobuf/value.cc b/extensions/protobuf/value.cc index ac2e6b2ec..f28f871ee 100644 --- a/extensions/protobuf/value.cc +++ b/extensions/protobuf/value.cc @@ -22,9 +22,9 @@ #include #include "google/protobuf/any.pb.h" +#include "google/protobuf/duration.pb.h" #include "google/protobuf/struct.pb.h" #include "google/protobuf/timestamp.pb.h" -#include "google/protobuf/descriptor.pb.h" #include "absl/base/attributes.h" #include "absl/base/call_once.h" #include "absl/base/macros.h" @@ -37,11 +37,11 @@ #include "absl/strings/string_view.h" #include "absl/strings/strip.h" #include "absl/time/time.h" -#include "absl/types/variant.h" #include "base/handle.h" #include "base/value.h" #include "base/values/list_value.h" #include "base/values/map_value.h" +#include "extensions/protobuf/internal/reflection.h" #include "extensions/protobuf/internal/time.h" #include "extensions/protobuf/internal/wrappers.h" #include "extensions/protobuf/memory_manager.h" @@ -52,18 +52,6 @@ namespace cel::extensions { namespace { -struct CreateStringValueFromProtoVisitor final { - ValueFactory& value_factory; - - absl::StatusOr> operator()(absl::string_view value) const { - return value_factory.CreateStringValue(value); - } - - absl::StatusOr> operator()(absl::Cord value) const { - return value_factory.CreateStringValue(std::move(value)); - } -}; - void AppendJsonValueDebugString(std::string& out, const google::protobuf::Value& value); @@ -788,16 +776,14 @@ absl::StatusOr> StringValueMessageCopyConverter( ValueFactory& value_factory, const google::protobuf::Message& value) { CEL_ASSIGN_OR_RETURN(auto wrapped, protobuf_internal::UnwrapStringValueProto(value)); - return absl::visit(CreateStringValueFromProtoVisitor{value_factory}, - std::move(wrapped)); + return value_factory.CreateUncheckedStringValue(std::move(wrapped)); } absl::StatusOr> StringValueMessageMoveConverter( ValueFactory& value_factory, google::protobuf::Message&& value) { CEL_ASSIGN_OR_RETURN(auto wrapped, protobuf_internal::UnwrapStringValueProto(value)); - return absl::visit(CreateStringValueFromProtoVisitor{value_factory}, - std::move(wrapped)); + return value_factory.CreateUncheckedStringValue(std::move(wrapped)); } absl::StatusOr> StringValueMessageBorrowConverter( @@ -805,8 +791,7 @@ absl::StatusOr> StringValueMessageBorrowConverter( const google::protobuf::Message& value) { CEL_ASSIGN_OR_RETURN(auto wrapped, protobuf_internal::UnwrapStringValueProto(value)); - return absl::visit(CreateStringValueFromProtoVisitor{value_factory}, - std::move(wrapped)); + return value_factory.CreateUncheckedStringValue(std::move(wrapped)); } absl::StatusOr> UInt32ValueMessageCopyConverter( @@ -1166,8 +1151,8 @@ absl::StatusOr> ValueMessageCopyConverter( return value_factory.CreateDoubleValue( reflect->GetDouble(value, field_desc)); case google::protobuf::Value::kStringValueFieldNumber: - return value_factory.CreateStringValue( - reflect->GetStringView(value, field_desc)); + return protobuf_internal::GetStringField(value_factory, value, reflect, + field_desc); case google::protobuf::Value::kListValueFieldNumber: return ListValueMessageCopyConverter( value_factory, reflect->GetMessage(value, field_desc)); @@ -1212,8 +1197,8 @@ absl::StatusOr> ValueMessageMoveConverter( return value_factory.CreateDoubleValue( reflect->GetDouble(value, field_desc)); case google::protobuf::Value::kStringValueFieldNumber: - return value_factory.CreateStringValue( - reflect->GetStringView(value, field_desc)); + return protobuf_internal::GetStringField(value_factory, value, reflect, + field_desc); case google::protobuf::Value::kListValueFieldNumber: return ListValueMessageMoveConverter( value_factory, @@ -1261,8 +1246,8 @@ absl::StatusOr> ValueMessageBorrowConverter( return value_factory.CreateDoubleValue( reflect->GetDouble(value, field_desc)); case google::protobuf::Value::kStringValueFieldNumber: - return value_factory.CreateBorrowedStringValue( - std::move(owner), reflect->GetStringView(value, field_desc)); + return protobuf_internal::GetBorrowedStringField( + value_factory, std::move(owner), value, reflect, field_desc); case google::protobuf::Value::kListValueFieldNumber: return ListValueMessageBorrowConverter( owner, value_factory, reflect->GetMessage(value, field_desc)); @@ -1309,8 +1294,8 @@ absl::StatusOr> ValueMessageOwnConverter( return value_factory.CreateDoubleValue( reflect->GetDouble(*value, field_desc)); case google::protobuf::Value::kStringValueFieldNumber: - return value_factory.CreateStringValue( - reflect->GetStringView(*value, field_desc)); + return protobuf_internal::GetStringField(value_factory, *value, reflect, + field_desc); case google::protobuf::Value::kListValueFieldNumber: return ListValueMessageCopyConverter( value_factory, reflect->GetMessage(*value, field_desc)); @@ -1362,17 +1347,11 @@ absl::StatusOr> AnyMessageCopyConverter( return absl::InvalidArgumentError( "value field descriptor has unexpected type"); } - std::string type_url_storage; - absl::string_view type_url; - if (!type_url_field->is_extension() && - type_url_field->options().ctype() == google::protobuf::FieldOptions::CORD) { - type_url_storage = reflect->GetString(value, type_url_field); - type_url = type_url_storage; - } else { - type_url = reflect->GetStringView(value, type_url_field); - } - return ProtoValue::Create(value_factory, type_url, - reflect->GetCord(value, value_field)); + std::string type_url; + return ProtoValue::Create( + value_factory, + reflect->GetStringReference(value, type_url_field, &type_url), + reflect->GetCord(value, value_field)); } absl::StatusOr> AnyMessageMoveConverter( From cd56747694af22ea49816a0e2b3c1d6ff4299283 Mon Sep 17 00:00:00 2001 From: jdtatum Date: Mon, 15 May 2023 22:18:03 +0000 Subject: [PATCH 273/303] Refactor cel-type registry to support using new style enums for identifying constants at plan time. For now, clients still need to explicitly declare the set of enums that can be used as constants in expression. PiperOrigin-RevId: 532242312 --- eval/compiler/BUILD | 1 + eval/compiler/resolver.cc | 31 ++++- eval/compiler/resolver_test.cc | 8 +- eval/public/BUILD | 5 + eval/public/cel_type_registry.cc | 158 ++++++++++++++++++++++---- eval/public/cel_type_registry.h | 18 +-- eval/public/cel_type_registry_test.cc | 131 ++++++++++++++++++--- 7 files changed, 296 insertions(+), 56 deletions(-) diff --git a/eval/compiler/BUILD b/eval/compiler/BUILD index c1e8740da..74af97aad 100644 --- a/eval/compiler/BUILD +++ b/eval/compiler/BUILD @@ -282,6 +282,7 @@ cc_library( hdrs = ["resolver.h"], deps = [ "//base:kind", + "//base:value", "//eval/internal:interop", "//eval/public:cel_type_registry", "//runtime:function_overload_reference", diff --git a/eval/compiler/resolver.cc b/eval/compiler/resolver.cc index 6c60b2679..8c7803bf0 100644 --- a/eval/compiler/resolver.cc +++ b/eval/compiler/resolver.cc @@ -2,18 +2,23 @@ #include #include +#include #include #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" +#include "base/values/enum_value.h" #include "eval/internal/interop.h" +#include "eval/public/cel_type_registry.h" #include "runtime/function_registry.h" namespace google::api::expr::runtime { +using ::cel::EnumType; using ::cel::Handle; +using ::cel::MemoryManager; using ::cel::Value; using ::cel::interop_internal::CreateIntValue; @@ -43,23 +48,39 @@ Resolver::Resolver(absl::string_view container, } for (const auto& prefix : namespace_prefixes_) { - for (auto iter = type_registry->enums_map().begin(); - iter != type_registry->enums_map().end(); ++iter) { + for (auto iter = type_registry->resolveable_enums().begin(); + iter != type_registry->resolveable_enums().end(); ++iter) { absl::string_view enum_name = iter->first; if (!absl::StartsWith(enum_name, prefix)) { continue; } auto remainder = absl::StripPrefix(enum_name, prefix); - for (const auto& enumerator : iter->second) { + const Handle& enum_type = iter->second; + + absl::StatusOr> + enum_value_iter_or = + enum_type->NewConstantIterator(MemoryManager::Global()); + + // Errors are not expected from the implementation in the type registry, + // but we need to swallow the error case to avoid compiler/lint warnings. + if (!enum_value_iter_or.ok()) { + continue; + } + auto enum_value_iter = *std::move(enum_value_iter_or); + while (enum_value_iter->HasNext()) { + absl::StatusOr constant = enum_value_iter->Next(); + if (!constant.ok()) { + break; + } // "prefixes" container is ascending-ordered. As such, we will be // assigning enum reference to the deepest available. // E.g. if both a.b.c.Name and a.b.Name are available, and // we try to reference "Name" with the scope of "a.b.c", // it will be resolved to "a.b.c.Name". auto key = absl::StrCat(remainder, !remainder.empty() ? "." : "", - enumerator.name); - enum_value_map_[key] = CreateIntValue(enumerator.number); + constant->name); + enum_value_map_[key] = CreateIntValue(constant->number); } } } diff --git a/eval/compiler/resolver_test.cc b/eval/compiler/resolver_test.cc index 392f65b89..25de79f48 100644 --- a/eval/compiler/resolver_test.cc +++ b/eval/compiler/resolver_test.cc @@ -80,14 +80,14 @@ TEST(ResolverTest, TestFindConstantEnum) { func_registry.InternalGetRegistry(), &type_registry); auto enum_value = resolver.FindConstant("TestEnum.TEST_ENUM_1", -1); - EXPECT_TRUE(enum_value); - EXPECT_TRUE(enum_value->Is()); + ASSERT_TRUE(enum_value); + ASSERT_TRUE(enum_value->Is()); EXPECT_THAT(enum_value.As()->value(), Eq(1L)); enum_value = resolver.FindConstant( ".google.api.expr.runtime.TestMessage.TestEnum.TEST_ENUM_2", -1); - EXPECT_TRUE(enum_value); - EXPECT_TRUE(enum_value->Is()); + ASSERT_TRUE(enum_value); + ASSERT_TRUE(enum_value->Is()); EXPECT_THAT(enum_value.As()->value(), Eq(2L)); } diff --git a/eval/public/BUILD b/eval/public/BUILD index eb798643b..62857817b 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -899,10 +899,13 @@ cc_library( hdrs = ["cel_type_registry.h"], deps = [ "//base:handle", + "//base:memory", + "//base:type", "//base:value", "//eval/internal:interop", "//eval/public/structs:legacy_type_provider", "//internal:no_destructor", + "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -919,12 +922,14 @@ cc_test( srcs = ["cel_type_registry_test.cc"], deps = [ ":cel_type_registry", + "//base:type", "//base:value", "//eval/public/structs:legacy_type_provider", "//eval/testutil:test_message_cc_proto", "//internal:testing", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", "@com_google_protobuf//:protobuf", ], ) diff --git a/eval/public/cel_type_registry.cc b/eval/public/cel_type_registry.cc index f4580ebad..60890a0b7 100644 --- a/eval/public/cel_type_registry.cc +++ b/eval/public/cel_type_registry.cc @@ -13,6 +13,11 @@ #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/types/optional.h" +#include "base/handle.h" +#include "base/memory.h" +#include "base/type_factory.h" +#include "base/types/enum_type.h" +#include "base/value.h" #include "eval/internal/interop.h" #include "internal/no_destructor.h" @@ -21,6 +26,9 @@ namespace google::api::expr::runtime { namespace { using cel::Handle; +using cel::MemoryManager; +using cel::TypeFactory; +using cel::UniqueRef; using cel::Value; using cel::interop_internal::CreateTypeValueFromView; @@ -42,30 +50,115 @@ const absl::node_hash_set& GetCoreTypes() { } using DescriptorSet = absl::flat_hash_set; -using EnumMap = - absl::flat_hash_map>; +using EnumMap = absl::flat_hash_map>; -void AddEnumFromDescriptor(const google::protobuf::EnumDescriptor* desc, EnumMap& map) { +// Type factory for ref-counted type instances. +cel::TypeFactory& GetDefaultTypeFactory() { + static TypeFactory* factory = new TypeFactory(cel::MemoryManager::Global()); + return *factory; +} + +// EnumType implementation for generic enums that are defined at runtime that +// can be resolved in expressions. +// +// Note: this implementation is primarily used for inspecting the full set of +// enum constants rather than looking up constants by name or number. +class ResolveableEnumType final : public cel::EnumType { + public: + using Constant = EnumType::Constant; + using Enumerator = CelTypeRegistry::Enumerator; + + ResolveableEnumType(std::string name, std::vector enumerators) + : name_(std::move(name)), enumerators_(std::move(enumerators)) {} + + static const ResolveableEnumType& Cast(const Type& type) { + ABSL_ASSERT(Is(type)); + return static_cast(type); + } + + absl::string_view name() const override { return name_; } + + size_t constant_count() const override { return enumerators_.size(); }; + + absl::StatusOr> NewConstantIterator( + MemoryManager& memory_manager) const override { + return cel::MakeUnique(memory_manager, enumerators_); + } + + const std::vector& enumerators() const { return enumerators_; } + + absl::StatusOr> FindConstantByName( + absl::string_view name) const override; + + absl::StatusOr> FindConstantByNumber( + int64_t number) const override; + + private: + class Iterator : public EnumType::ConstantIterator { + public: + using Constant = EnumType::Constant; + + explicit Iterator(absl::Span enumerators) + : idx_(0), enumerators_(enumerators) {} + + bool HasNext() override { return idx_ < enumerators_.size(); } + + absl::StatusOr Next() override { + if (!HasNext()) { + return absl::FailedPreconditionError( + "Next() called when HasNext() false in " + "ResolveableEnumType::Iterator"); + } + int current = idx_; + idx_++; + return Constant(MakeConstantId(enumerators_[current].number), + enumerators_[current].name, enumerators_[current].number); + } + + absl::StatusOr NextName() override { + CEL_ASSIGN_OR_RETURN(Constant constant, Next()); + + return constant.name; + } + + absl::StatusOr NextNumber() override { + CEL_ASSIGN_OR_RETURN(Constant constant, Next()); + + return constant.number; + } + + private: + // The index for the next returned value. + int idx_; + absl::Span enumerators_; + }; + + // Implement EnumType. + cel::internal::TypeInfo TypeId() const override { + return cel::internal::TypeId(); + } + + std::string name_; + // TODO(issues/5): this could be indexed by name and/or number if strong + // enum typing is needed at runtime. + std::vector enumerators_; +}; + +void AddEnumFromDescriptor(const google::protobuf::EnumDescriptor* desc, + CelTypeRegistry& registry) { std::vector enumerators; enumerators.reserve(desc->value_count()); for (int i = 0; i < desc->value_count(); i++) { enumerators.push_back({desc->value(i)->name(), desc->value(i)->number()}); } - map.insert(std::pair(desc->full_name(), std::move(enumerators))); + registry.RegisterEnum(desc->full_name(), std::move(enumerators)); } -// Portable version. Add overloads for specfic core supported enums. +// Portable version. Add overloads for specific core supported enums. template struct EnumAdderT { template void AddEnum(DescriptorSet&) {} - - template - void AddEnum(EnumMap& map) { - if constexpr (std::is_same_v) { - map["google.protobuf.NullValue"] = {{"NULL_VALUE", 0}}; - } - } }; template @@ -75,16 +168,10 @@ struct EnumAdderT()); } - - template - void AddEnum(EnumMap& map) { - const google::protobuf::EnumDescriptor* desc = google::protobuf::GetEnumDescriptor(); - AddEnumFromDescriptor(desc, map); - } }; // Enable loading the linked descriptor if using the full proto runtime. -// Otherwise, only support explcitly defined enums. +// Otherwise, only support explicitly defined enums. using EnumAdder = EnumAdderT; const absl::flat_hash_set& GetCoreEnums() { @@ -98,9 +185,31 @@ const absl::flat_hash_set& GetCoreEnums } // namespace +absl::StatusOr> +ResolveableEnumType::FindConstantByName(absl::string_view name) const { + for (const Enumerator& enumerator : enumerators_) { + if (enumerator.name == name) { + return ResolveableEnumType::Constant(MakeConstantId(enumerator.number), + enumerator.name, enumerator.number); + } + } + return absl::nullopt; +} + +absl::StatusOr> +ResolveableEnumType::FindConstantByNumber(int64_t number) const { + for (const Enumerator& enumerator : enumerators_) { + if (enumerator.number == number) { + return ResolveableEnumType::Constant(MakeConstantId(enumerator.number), + enumerator.name, enumerator.number); + } + } + return absl::nullopt; +} + CelTypeRegistry::CelTypeRegistry() : types_(GetCoreTypes()), enums_(GetCoreEnums()) { - EnumAdder().AddEnum(enums_map_); + RegisterEnum("google.protobuf.NullValue", {{"NULL_VALUE", 0}}); } void CelTypeRegistry::Register(std::string fully_qualified_type_name) { @@ -111,12 +220,17 @@ void CelTypeRegistry::Register(std::string fully_qualified_type_name) { void CelTypeRegistry::Register(const google::protobuf::EnumDescriptor* enum_descriptor) { enums_.insert(enum_descriptor); - AddEnumFromDescriptor(enum_descriptor, enums_map_); + AddEnumFromDescriptor(enum_descriptor, *this); } void CelTypeRegistry::RegisterEnum(absl::string_view enum_name, std::vector enumerators) { - enums_map_[enum_name] = std::move(enumerators); + absl::StatusOr> result_or = + GetDefaultTypeFactory().CreateEnumType( + std::string(enum_name), std::move(enumerators)); + // For this setup, the type factory should never return an error. + result_or.IgnoreError(); + resolveable_enums_[enum_name] = std::move(result_or).value(); } std::shared_ptr diff --git a/eval/public/cel_type_registry.h b/eval/public/cel_type_registry.h index bf97fbd37..ae622fef9 100644 --- a/eval/public/cel_type_registry.h +++ b/eval/public/cel_type_registry.h @@ -7,6 +7,8 @@ #include #include "google/protobuf/descriptor.h" +#include "absl/base/attributes.h" +#include "absl/base/call_once.h" #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" @@ -14,6 +16,7 @@ #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "base/handle.h" +#include "base/types/enum_type.h" #include "base/value.h" #include "eval/public/structs/legacy_type_provider.h" @@ -33,7 +36,7 @@ namespace google::api::expr::runtime { // pools. class CelTypeRegistry { public: - // Internal representation for enumerators. + // Representation of an enum constant. struct Enumerator { std::string name; int64_t number; @@ -75,7 +78,7 @@ class CelTypeRegistry { std::shared_ptr GetFirstTypeProvider() const; // Find a type adapter given a fully qualified type name. - // Adapter provides a generic interface for the reflecion operations the + // Adapter provides a generic interface for the reflection operations the // interpreter needs to provide. absl::optional FindTypeAdapter( absl::string_view fully_qualified_type_name) const; @@ -92,10 +95,10 @@ class CelTypeRegistry { } // Return the registered enums configured within the type registry in the - // internal format. - const absl::flat_hash_map>& enums_map() - const { - return enums_map_; + // internal format that can be identified as int constants at plan time. + const absl::flat_hash_map>& + resolveable_enums() const { + return resolveable_enums_; } private: @@ -106,7 +109,8 @@ class CelTypeRegistry { // Set of registered enums. absl::flat_hash_set enums_; // Internal representation for enums. - absl::flat_hash_map> enums_map_; + absl::flat_hash_map> + resolveable_enums_; std::vector> type_providers_; }; diff --git a/eval/public/cel_type_registry_test.cc b/eval/public/cel_type_registry_test.cc index 9b20fe7c0..59840f449 100644 --- a/eval/public/cel_type_registry_test.cc +++ b/eval/public/cel_type_registry_test.cc @@ -7,9 +7,10 @@ #include #include "google/protobuf/struct.pb.h" -#include "google/protobuf/message.h" #include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "base/types/enum_type.h" #include "base/values/type_value.h" #include "eval/public/structs/legacy_type_provider.h" #include "eval/testutil/test_message.pb.h" @@ -19,14 +20,21 @@ namespace google::api::expr::runtime { namespace { +using ::cel::EnumType; +using ::cel::Handle; +using ::cel::MemoryManager; using ::cel::TypeValue; using testing::AllOf; using testing::Contains; using testing::Eq; using testing::IsEmpty; using testing::Key; +using testing::Optional; using testing::Pair; +using testing::Truly; using testing::UnorderedElementsAre; +using cel::internal::IsOkAndHolds; +using cel::internal::StatusIs; class TestTypeProvider : public LegacyTypeProvider { public: @@ -50,20 +58,31 @@ class TestTypeProvider : public LegacyTypeProvider { }; MATCHER_P(MatchesEnumDescriptor, desc, "") { - const std::vector& enumerators = arg; + const Handle& enum_type = arg; - if (enumerators.size() != desc->value_count()) { + if (enum_type->constant_count() != desc->value_count()) { return false; } + auto iter_or = enum_type->NewConstantIterator(MemoryManager::Global()); + if (!iter_or.ok()) { + return false; + } + + auto iter = std::move(iter_or).value(); + for (int i = 0; i < desc->value_count(); i++) { + absl::StatusOr constant = iter->Next(); + if (!constant.ok()) { + return false; + } + const auto* value_desc = desc->value(i); - const auto& enumerator = enumerators[i]; - if (value_desc->name() != enumerator.name) { + if (value_desc->name() != constant->name) { return false; } - if (value_desc->number() != enumerator.number) { + if (value_desc->number() != constant->number) { return false; } } @@ -105,7 +124,7 @@ struct RegisterEnumDescriptorTestT< EXPECT_THAT(enum_set, Eq(expected_set)); EXPECT_THAT( - registry.enums_map(), + registry.resolveable_enums(), AllOf( Contains(Pair( "google.protobuf.NullValue", @@ -134,22 +153,98 @@ TEST(CelTypeRegistryTest, RegisterEnum) { {"TEST_ENUM_3", 30}, }); - EXPECT_THAT( - registry.enums_map(), - Contains(Pair("google.api.expr.runtime.TestMessage.TestEnum", - Contains(testing::Truly( - [](const CelTypeRegistry::Enumerator& enumerator) { - return enumerator.name == "TEST_ENUM_2" && - enumerator.number == 20; - }))))); + EXPECT_THAT(registry.resolveable_enums(), + Contains(Pair( + "google.api.expr.runtime.TestMessage.TestEnum", + testing::Truly([](const Handle& enum_type) { + auto constant = + enum_type->FindConstantByName("TEST_ENUM_2"); + return enum_type->name() == + "google.api.expr.runtime.TestMessage.TestEnum" && + constant.value()->number == 20; + })))); +} + +MATCHER_P(ConstantIntValue, x, "") { + const EnumType::Constant& constant = arg; + + return constant.number == x; +} + +MATCHER_P(ConstantName, x, "") { + const EnumType::Constant& constant = arg; + + return constant.name == x; +} + +TEST(CelTypeRegistryTest, ImplementsEnumType) { + CelTypeRegistry registry; + registry.RegisterEnum("google.api.expr.runtime.TestMessage.TestEnum", + { + {"TEST_ENUM_UNSPECIFIED", 0}, + {"TEST_ENUM_1", 10}, + {"TEST_ENUM_2", 20}, + {"TEST_ENUM_3", 30}, + }); + + ASSERT_THAT(registry.resolveable_enums(), + Contains(Key("google.api.expr.runtime.TestMessage.TestEnum"))); + + const Handle& enum_type = registry.resolveable_enums().at( + "google.api.expr.runtime.TestMessage.TestEnum"); + + EXPECT_TRUE(enum_type->Is()); + + EXPECT_THAT(enum_type->FindConstantByName("TEST_ENUM_UNSPECIFIED"), + IsOkAndHolds(Optional(ConstantIntValue(0)))); + EXPECT_THAT(enum_type->FindConstantByName("TEST_ENUM_1"), + IsOkAndHolds(Optional(ConstantIntValue(10)))); + EXPECT_THAT(enum_type->FindConstantByName("TEST_ENUM_4"), + IsOkAndHolds(Eq(absl::nullopt))); + + EXPECT_THAT(enum_type->FindConstantByNumber(20), + IsOkAndHolds(Optional(ConstantName("TEST_ENUM_2")))); + EXPECT_THAT(enum_type->FindConstantByNumber(30), + IsOkAndHolds(Optional(ConstantName("TEST_ENUM_3")))); + EXPECT_THAT(enum_type->FindConstantByNumber(42), + IsOkAndHolds(Eq(absl::nullopt))); + + std::vector names; + ASSERT_OK_AND_ASSIGN(auto iter, + enum_type->NewConstantIterator(MemoryManager::Global())); + while (iter->HasNext()) { + ASSERT_OK_AND_ASSIGN(absl::string_view name, iter->NextName()); + names.push_back(std::string(name)); + } + + EXPECT_THAT(names, + UnorderedElementsAre("TEST_ENUM_UNSPECIFIED", "TEST_ENUM_1", + "TEST_ENUM_2", "TEST_ENUM_3")); + EXPECT_THAT(iter->NextName(), + StatusIs(absl::StatusCode::kFailedPrecondition)); + + std::vector numbers; + ASSERT_OK_AND_ASSIGN(iter, + enum_type->NewConstantIterator(MemoryManager::Global())); + while (iter->HasNext()) { + ASSERT_OK_AND_ASSIGN(numbers.emplace_back(), iter->NextNumber()); + } + + EXPECT_THAT(numbers, UnorderedElementsAre(0, 10, 20, 30)); + EXPECT_THAT(iter->NextNumber(), + StatusIs(absl::StatusCode::kFailedPrecondition)); } TEST(CelTypeRegistryTest, TestRegisterBuiltInEnum) { CelTypeRegistry registry; - ASSERT_THAT(registry.enums_map(), Contains(Key("google.protobuf.NullValue"))); - EXPECT_THAT(registry.enums_map().at("google.protobuf.NullValue"), - UnorderedElementsAre(EqualsEnumerator("NULL_VALUE", 0))); + ASSERT_THAT(registry.resolveable_enums(), + Contains(Key("google.protobuf.NullValue"))); + EXPECT_THAT(registry.resolveable_enums() + .at("google.protobuf.NullValue") + ->FindConstantByName("NULL_VALUE"), + IsOkAndHolds(Optional(Truly( + [](const EnumType::Constant& c) { return c.number == 0; })))); } TEST(CelTypeRegistryTest, TestRegisterTypeName) { From 1d4e35ec3d01fa434f97d1e946b43ca36810a68f Mon Sep 17 00:00:00 2001 From: jdtatum Date: Mon, 15 May 2023 22:21:19 +0000 Subject: [PATCH 274/303] Add new method to inspect registered enum set for client tests. PiperOrigin-RevId: 532243263 --- eval/public/cel_type_registry.h | 18 ++++++++++++++++++ eval/public/cel_type_registry_test.cc | 15 ++++++--------- 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/eval/public/cel_type_registry.h b/eval/public/cel_type_registry.h index ae622fef9..1f563dc78 100644 --- a/eval/public/cel_type_registry.h +++ b/eval/public/cel_type_registry.h @@ -89,6 +89,7 @@ class CelTypeRegistry { absl::string_view fully_qualified_type_name) const; // Return the set of enums configured within the type registry. + ABSL_DEPRECATED("Use GetRegisteredEnums to validate RegisterEnum calls.") inline const absl::flat_hash_set& Enums() const { return enums_; @@ -101,6 +102,23 @@ class CelTypeRegistry { return resolveable_enums_; } + // Return the registered enums configured within the type registry. + // + // This is provided for validating registry setup, it should not be used + // internally. + // + // Invalidated whenever registered enums are updated. + absl::flat_hash_set ListResolveableEnums() const { + absl::flat_hash_set result; + result.reserve(resolveable_enums_.size()); + + for (const auto& entry : resolveable_enums_) { + result.insert(entry.first); + } + + return result; + } + private: mutable absl::Mutex mutex_; // node_hash_set provides pointer-stability, which is required for the diff --git a/eval/public/cel_type_registry_test.cc b/eval/public/cel_type_registry_test.cc index 59840f449..68bd3f622 100644 --- a/eval/public/cel_type_registry_test.cc +++ b/eval/public/cel_type_registry_test.cc @@ -102,7 +102,8 @@ struct RegisterEnumDescriptorTestT { // Portable version doesn't support registering at this time. CelTypeRegistry registry; - EXPECT_THAT(registry.Enums(), IsEmpty()); + EXPECT_THAT(registry.ListResolveableEnums(), + UnorderedElementsAre("google.protobuf.NullValue")); } }; @@ -114,14 +115,10 @@ struct RegisterEnumDescriptorTestT< CelTypeRegistry registry; registry.Register(google::protobuf::GetEnumDescriptor()); - absl::flat_hash_set enum_set; - for (auto enum_desc : registry.Enums()) { - enum_set.insert(enum_desc->full_name()); - } - absl::flat_hash_set expected_set{ - "google.protobuf.NullValue", - "google.api.expr.runtime.TestMessage.TestEnum"}; - EXPECT_THAT(enum_set, Eq(expected_set)); + EXPECT_THAT( + registry.ListResolveableEnums(), + UnorderedElementsAre("google.protobuf.NullValue", + "google.api.expr.runtime.TestMessage.TestEnum")); EXPECT_THAT( registry.resolveable_enums(), From 10be43b8d06cdba4610182dedf9d2413b93819c2 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Tue, 16 May 2023 02:35:53 +0000 Subject: [PATCH 275/303] Separate container operators from the core builtin function registrar. PiperOrigin-RevId: 532299260 --- eval/public/BUILD | 47 ++++++ eval/public/builtin_func_registrar.cc | 103 +----------- eval/public/builtin_func_registrar.h | 1 + eval/public/container_function_registrar.cc | 147 ++++++++++++++++++ eval/public/container_function_registrar.h | 36 +++++ .../container_function_registrar_test.cc | 93 +++++++++++ 6 files changed, 327 insertions(+), 100 deletions(-) create mode 100644 eval/public/container_function_registrar.cc create mode 100644 eval/public/container_function_registrar.h create mode 100644 eval/public/container_function_registrar_test.cc diff --git a/eval/public/BUILD b/eval/public/BUILD index 62857817b..0bbb317a8 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -289,6 +289,7 @@ cc_library( ":cel_options", ":cel_value", ":comparison_functions", + ":container_function_registrar", ":equality_function_registrar", ":logical_function_registrar", ":portable_cel_function_adapter", @@ -307,6 +308,7 @@ cc_library( "//internal:time", "//internal:utf8", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@com_google_absl//absl/types:optional", @@ -426,6 +428,51 @@ cc_test( ], ) +cc_library( + name = "container_function_registrar", + srcs = [ + "container_function_registrar.cc", + ], + hdrs = [ + "container_function_registrar.h", + ], + deps = [ + ":cel_function_registry", + ":cel_options", + ":portable_cel_function_adapter", + "//base:builtins", + "//base:function_adapter", + "//base:handle", + "//base:value", + "//eval/eval:mutable_list_impl", + "//eval/internal:interop", + "//eval/public/containers:container_backed_list_impl", + "//extensions/protobuf:memory_manager", + "@com_google_absl//absl/status", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "container_function_registrar_test", + size = "small", + srcs = [ + "container_function_registrar_test.cc", + ], + deps = [ + ":activation", + ":cel_expr_builder_factory", + ":cel_expression", + ":cel_value", + ":container_function_registrar", + ":equality_function_registrar", + "//eval/public/containers:container_backed_list_impl", + "//eval/public/testing:matchers", + "//internal:testing", + "//parser", + ], +) + cc_library( name = "logical_function_registrar", srcs = [ diff --git a/eval/public/builtin_func_registrar.cc b/eval/public/builtin_func_registrar.cc index fc6c23d12..451312e80 100644 --- a/eval/public/builtin_func_registrar.cc +++ b/eval/public/builtin_func_registrar.cc @@ -15,19 +15,19 @@ #include "eval/public/builtin_func_registrar.h" #include -#include #include #include #include #include -#include #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/match.h" #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" +#include "absl/time/civil_time.h" #include "absl/time/time.h" #include "absl/types/optional.h" #include "base/function_adapter.h" @@ -47,6 +47,7 @@ #include "eval/public/cel_options.h" #include "eval/public/cel_value.h" #include "eval/public/comparison_functions.h" +#include "eval/public/container_function_registrar.h" #include "eval/public/containers/container_backed_list_impl.h" #include "eval/public/equality_function_registrar.h" #include "eval/public/logical_function_registrar.h" @@ -358,24 +359,6 @@ CelValue HeterogeneousEqualityIn(Arena* arena, CelValue value, return CelValue::CreateBool(false); } -// AppendList will append the elements in value2 to value1. -// -// This call will only be invoked within comprehensions where `value1` is an -// intermediate result which cannot be directly assigned or co-mingled with a -// user-provided list. -const CelList* AppendList(Arena* arena, const CelList* value1, - const CelList* value2) { - // The `value1` object cannot be directly addressed and is an intermediate - // variable. Once the comprehension completes this value will in effect be - // treated as immutable. - MutableListImpl* mutable_list = const_cast( - cel::internal::down_cast(value1)); - for (int i = 0; i < value2->size(); i++) { - mutable_list->Append((*value2).Get(arena, i)); - } - return mutable_list; -} - // Concatenation for string type. absl::StatusOr> ConcatString(ValueFactory& factory, const StringValue& value1, @@ -392,43 +375,6 @@ absl::StatusOr> ConcatBytes(ValueFactory& factory, absl::StrCat(value1.ToString(), value2.ToString())); } -// Concatenation for CelList type. -absl::StatusOr> ConcatList(ValueFactory& factory, - const Handle& value1, - const Handle& value2) { - std::vector joined_values; - - int size1 = value1->size(); - if (size1 == 0) { - return value2; - } - int size2 = value2->size(); - if (size2 == 0) { - return value1; - } - joined_values.reserve(size1 + size2); - - Arena* arena = cel::extensions::ProtoMemoryManager::CastToProtoArena( - factory.memory_manager()); - - ListValue::GetContext context(factory); - for (int i = 0; i < size1; i++) { - CEL_ASSIGN_OR_RETURN(Handle elem, value1->Get(context, i)); - joined_values.push_back( - cel::interop_internal::ModernValueToLegacyValueOrDie(arena, elem)); - } - for (int i = 0; i < size2; i++) { - CEL_ASSIGN_OR_RETURN(Handle elem, value2->Get(context, i)); - joined_values.push_back( - cel::interop_internal::ModernValueToLegacyValueOrDie(arena, elem)); - } - - auto concatenated = - Arena::Create(arena, joined_values); - - return cel::interop_internal::CreateLegacyListValue(concatenated); -} - // Timestamp absl::Status FindTimeBreakdown(absl::Time timestamp, absl::string_view tz, absl::TimeZone::CivilInfo* breakdown) { @@ -1600,49 +1546,6 @@ absl::Status RegisterTimeFunctions(CelFunctionRegistry* registry, return absl::OkStatus(); } - -int64_t MapSizeImpl(ValueFactory&, const MapValue& value) { - return value.size(); -} - -int64_t ListSizeImpl(ValueFactory&, const ListValue& value) { - return value.size(); -} - -absl::Status RegisterContainerFunctions(CelFunctionRegistry* registry, - const InterpreterOptions& options) { - // receiver style = true/false - // Support both the global and receiver style size() for lists and maps. - for (bool receiver_style : {true, false}) { - CEL_RETURN_IF_ERROR(registry->Register( - UnaryFunctionAdapter::CreateDescriptor( - builtin::kSize, receiver_style), - UnaryFunctionAdapter::WrapFunction( - ListSizeImpl))); - - CEL_RETURN_IF_ERROR(registry->Register( - UnaryFunctionAdapter::CreateDescriptor( - builtin::kSize, receiver_style), - UnaryFunctionAdapter::WrapFunction( - MapSizeImpl))); - } - - if (options.enable_list_concat) { - CEL_RETURN_IF_ERROR(registry->Register( - BinaryFunctionAdapter>, const ListValue&, - const ListValue&>::CreateDescriptor(builtin::kAdd, - false), - BinaryFunctionAdapter< - absl::StatusOr>, const Handle&, - const Handle&>::WrapFunction(ConcatList))); - } - - return registry->Register(PortableBinaryFunctionAdapter< - const CelList*, const CelList*, - const CelList*>::Create(builtin::kRuntimeListAppend, - false, AppendList)); -} - } // namespace absl::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry, diff --git a/eval/public/builtin_func_registrar.h b/eval/public/builtin_func_registrar.h index 4afaaf1a6..636a30820 100644 --- a/eval/public/builtin_func_registrar.h +++ b/eval/public/builtin_func_registrar.h @@ -1,6 +1,7 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_BUILTIN_FUNC_REGISTRAR_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_BUILTIN_FUNC_REGISTRAR_H_ +#include "absl/status/status.h" #include "eval/public/cel_function.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" diff --git a/eval/public/container_function_registrar.cc b/eval/public/container_function_registrar.cc new file mode 100644 index 000000000..8489336ef --- /dev/null +++ b/eval/public/container_function_registrar.cc @@ -0,0 +1,147 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/container_function_registrar.h" + +#include + +#include "absl/status/status.h" +#include "base/builtins.h" +#include "base/function_adapter.h" +#include "base/handle.h" +#include "base/value.h" +#include "base/value_factory.h" +#include "base/values/list_value.h" +#include "base/values/map_value.h" +#include "eval/eval/mutable_list_impl.h" +#include "eval/internal/interop.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_options.h" +#include "eval/public/containers/container_backed_list_impl.h" +#include "eval/public/portable_cel_function_adapter.h" +#include "extensions/protobuf/memory_manager.h" +#include "google/protobuf/arena.h" + +namespace google::api::expr::runtime { +namespace { + +using ::cel::BinaryFunctionAdapter; +using ::cel::Handle; +using ::cel::ListValue; +using ::cel::MapValue; +using ::cel::UnaryFunctionAdapter; +using ::cel::Value; +using ::cel::ValueFactory; +using ::google::protobuf::Arena; + +int64_t MapSizeImpl(ValueFactory&, const MapValue& value) { + return value.size(); +} + +int64_t ListSizeImpl(ValueFactory&, const ListValue& value) { + return value.size(); +} + +// Concatenation for CelList type. +absl::StatusOr> ConcatList(ValueFactory& factory, + const Handle& value1, + const Handle& value2) { + std::vector joined_values; + + int size1 = value1->size(); + if (size1 == 0) { + return value2; + } + int size2 = value2->size(); + if (size2 == 0) { + return value1; + } + joined_values.reserve(size1 + size2); + + google::protobuf::Arena* arena = cel::extensions::ProtoMemoryManager::CastToProtoArena( + factory.memory_manager()); + + ListValue::GetContext context(factory); + for (int i = 0; i < size1; i++) { + CEL_ASSIGN_OR_RETURN(Handle elem, value1->Get(context, i)); + joined_values.push_back( + cel::interop_internal::ModernValueToLegacyValueOrDie(arena, elem)); + } + for (int i = 0; i < size2; i++) { + CEL_ASSIGN_OR_RETURN(Handle elem, value2->Get(context, i)); + joined_values.push_back( + cel::interop_internal::ModernValueToLegacyValueOrDie(arena, elem)); + } + + auto concatenated = + Arena::Create(arena, joined_values); + + return cel::interop_internal::CreateLegacyListValue(concatenated); +} + +// AppendList will append the elements in value2 to value1. +// +// This call will only be invoked within comprehensions where `value1` is an +// intermediate result which cannot be directly assigned or co-mingled with a +// user-provided list. +const CelList* AppendList(Arena* arena, const CelList* value1, + const CelList* value2) { + // The `value1` object cannot be directly addressed and is an intermediate + // variable. Once the comprehension completes this value will in effect be + // treated as immutable. + MutableListImpl* mutable_list = const_cast( + cel::internal::down_cast(value1)); + for (int i = 0; i < value2->size(); i++) { + mutable_list->Append((*value2).Get(arena, i)); + } + return mutable_list; +} +} // namespace + +absl::Status RegisterContainerFunctions(CelFunctionRegistry* registry, + const InterpreterOptions& options) { + // receiver style = true/false + // Support both the global and receiver style size() for lists and maps. + for (bool receiver_style : {true, false}) { + CEL_RETURN_IF_ERROR(registry->Register( + cel::UnaryFunctionAdapter::CreateDescriptor( + cel::builtin::kSize, receiver_style), + UnaryFunctionAdapter::WrapFunction( + ListSizeImpl))); + + CEL_RETURN_IF_ERROR(registry->Register( + UnaryFunctionAdapter::CreateDescriptor( + cel::builtin::kSize, receiver_style), + UnaryFunctionAdapter::WrapFunction( + MapSizeImpl))); + } + + if (options.enable_list_concat) { + CEL_RETURN_IF_ERROR(registry->Register( + BinaryFunctionAdapter< + absl::StatusOr>, const ListValue&, + const ListValue&>::CreateDescriptor(cel::builtin::kAdd, false), + BinaryFunctionAdapter< + absl::StatusOr>, const Handle&, + const Handle&>::WrapFunction(ConcatList))); + } + + return registry->Register( + PortableBinaryFunctionAdapter< + const CelList*, const CelList*, + const CelList*>::Create(cel::builtin::kRuntimeListAppend, false, + AppendList)); +} + +} // namespace google::api::expr::runtime diff --git a/eval/public/container_function_registrar.h b/eval/public/container_function_registrar.h new file mode 100644 index 000000000..9ce268439 --- /dev/null +++ b/eval/public/container_function_registrar.h @@ -0,0 +1,36 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINER_FUNCTION_REGISTRAR_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINER_FUNCTION_REGISTRAR_H_ + +#include "absl/status/status.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_options.h" + +namespace google::api::expr::runtime { + +// Register built in container functions. +// +// Most users should prefer to use RegisterBuiltinFunctions. +// +// This call is included in RegisterBuiltinFunctions -- calling both +// RegisterBuiltinFunctions and RegisterContainerFunctions directly on the same +// registry will result in an error. +absl::Status RegisterContainerFunctions(CelFunctionRegistry* registry, + const InterpreterOptions& options); + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINER_FUNCTION_REGISTRAR_H_ diff --git a/eval/public/container_function_registrar_test.cc b/eval/public/container_function_registrar_test.cc new file mode 100644 index 000000000..675254974 --- /dev/null +++ b/eval/public/container_function_registrar_test.cc @@ -0,0 +1,93 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/container_function_registrar.h" + +#include +#include + +#include "eval/public/activation.h" +#include "eval/public/cel_expr_builder_factory.h" +#include "eval/public/cel_expression.h" +#include "eval/public/cel_value.h" +#include "eval/public/containers/container_backed_list_impl.h" +#include "eval/public/equality_function_registrar.h" +#include "eval/public/testing/matchers.h" +#include "internal/testing.h" +#include "parser/parser.h" + +namespace google::api::expr::runtime { +namespace { + +using testing::ValuesIn; + +struct TestCase { + std::string test_name; + std::string expr; + absl::StatusOr result = CelValue::CreateBool(true); +}; + +const CelList& CelNumberListExample() { + static ContainerBackedListImpl* example = + new ContainerBackedListImpl({CelValue::CreateInt64(1)}); + return *example; +} + +void ExpectResult(const TestCase& test_case) { + auto parsed_expr = parser::Parse(test_case.expr); + ASSERT_OK(parsed_expr); + const Expr& expr_ast = parsed_expr->expr(); + const SourceInfo& source_info = parsed_expr->source_info(); + InterpreterOptions options; + options.enable_timestamp_duration_overflow_errors = true; + options.enable_comprehension_list_append = true; + std::unique_ptr builder = + CreateCelExpressionBuilder(options); + ASSERT_OK(RegisterContainerFunctions(builder->GetRegistry(), options)); + // Needed to avoid error - No overloads provided for FunctionStep creation. + ASSERT_OK(RegisterEqualityFunctions(builder->GetRegistry(), options)); + ASSERT_OK_AND_ASSIGN(auto cel_expression, + builder->CreateExpression(&expr_ast, &source_info)); + + Activation activation; + + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(auto value, + cel_expression->Evaluate(activation, &arena)); + EXPECT_THAT(value, test::EqualsCelValue(*test_case.result)); +} + +using ContainerFunctionParamsTest = testing::TestWithParam; +TEST_P(ContainerFunctionParamsTest, StandardFunctions) { + ExpectResult(GetParam()); +} + +INSTANTIATE_TEST_SUITE_P( + ContainerFunctionParamsTest, ContainerFunctionParamsTest, + ValuesIn( + {{"FilterNumbers", "[1, 2, 3].filter(num, num == 1)", + CelValue::CreateList(&CelNumberListExample())}, + {"ListConcatEmptyInputs", "[] + [] == []", CelValue::CreateBool(true)}, + {"ListConcatRightEmpty", "[1] + [] == [1]", + CelValue::CreateBool(true)}, + {"ListConcatLeftEmpty", "[] + [1] == [1]", CelValue::CreateBool(true)}, + {"ListConcat", "[2] + [1] == [2, 1]", CelValue::CreateBool(true)}, + {"ListSize", "[1, 2, 3].size() == 3", CelValue::CreateBool(true)}, + {"MapSize", "{1: 2, 2: 4}.size() == 2", CelValue::CreateBool(true)}, + {"EmptyListSize", "size({}) == 0", CelValue::CreateBool(true)}}), + [](const testing::TestParamInfo& + info) { return info.param.test_name; }); + +} // namespace +} // namespace google::api::expr::runtime From 2cddf41df88d5cf1ebd22ab663cc6e40673a76c2 Mon Sep 17 00:00:00 2001 From: jdtatum Date: Tue, 16 May 2023 18:55:22 +0000 Subject: [PATCH 276/303] Remove stored protobuf enum descriptors from CelTypeRegistry. PiperOrigin-RevId: 532529938 --- eval/public/BUILD | 2 -- eval/public/cel_type_registry.cc | 39 ++------------------------------ eval/public/cel_type_registry.h | 12 ---------- 3 files changed, 2 insertions(+), 51 deletions(-) diff --git a/eval/public/BUILD b/eval/public/BUILD index 0bbb317a8..b030b86c5 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -951,8 +951,6 @@ cc_library( "//base:value", "//eval/internal:interop", "//eval/public/structs:legacy_type_provider", - "//internal:no_destructor", - "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", diff --git a/eval/public/cel_type_registry.cc b/eval/public/cel_type_registry.cc index 60890a0b7..caff983df 100644 --- a/eval/public/cel_type_registry.cc +++ b/eval/public/cel_type_registry.cc @@ -2,12 +2,9 @@ #include #include -#include #include #include -#include "google/protobuf/struct.pb.h" -#include "google/protobuf/descriptor.h" #include "absl/container/flat_hash_set.h" #include "absl/container/node_hash_set.h" #include "absl/strings/string_view.h" @@ -19,7 +16,7 @@ #include "base/types/enum_type.h" #include "base/value.h" #include "eval/internal/interop.h" -#include "internal/no_destructor.h" +#include "google/protobuf/descriptor.h" namespace google::api::expr::runtime { @@ -49,7 +46,6 @@ const absl::node_hash_set& GetCoreTypes() { return *kCoreTypes; } -using DescriptorSet = absl::flat_hash_set; using EnumMap = absl::flat_hash_map>; // Type factory for ref-counted type instances. @@ -154,35 +150,6 @@ void AddEnumFromDescriptor(const google::protobuf::EnumDescriptor* desc, registry.RegisterEnum(desc->full_name(), std::move(enumerators)); } -// Portable version. Add overloads for specific core supported enums. -template -struct EnumAdderT { - template - void AddEnum(DescriptorSet&) {} -}; - -template -struct EnumAdderT, void>::type> { - template - void AddEnum(DescriptorSet& set) { - set.insert(google::protobuf::GetEnumDescriptor()); - } -}; - -// Enable loading the linked descriptor if using the full proto runtime. -// Otherwise, only support explicitly defined enums. -using EnumAdder = EnumAdderT; - -const absl::flat_hash_set& GetCoreEnums() { - static cel::internal::NoDestructor kCoreEnums([]() { - absl::flat_hash_set instance; - EnumAdder().AddEnum(instance); - return instance; - }()); - return *kCoreEnums; -} - } // namespace absl::StatusOr> @@ -207,8 +174,7 @@ ResolveableEnumType::FindConstantByNumber(int64_t number) const { return absl::nullopt; } -CelTypeRegistry::CelTypeRegistry() - : types_(GetCoreTypes()), enums_(GetCoreEnums()) { +CelTypeRegistry::CelTypeRegistry() : types_(GetCoreTypes()) { RegisterEnum("google.protobuf.NullValue", {{"NULL_VALUE", 0}}); } @@ -219,7 +185,6 @@ void CelTypeRegistry::Register(std::string fully_qualified_type_name) { } void CelTypeRegistry::Register(const google::protobuf::EnumDescriptor* enum_descriptor) { - enums_.insert(enum_descriptor); AddEnumFromDescriptor(enum_descriptor, *this); } diff --git a/eval/public/cel_type_registry.h b/eval/public/cel_type_registry.h index 1f563dc78..36e4b1db8 100644 --- a/eval/public/cel_type_registry.h +++ b/eval/public/cel_type_registry.h @@ -6,9 +6,6 @@ #include #include -#include "google/protobuf/descriptor.h" -#include "absl/base/attributes.h" -#include "absl/base/call_once.h" #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" @@ -88,13 +85,6 @@ class CelTypeRegistry { cel::Handle FindType( absl::string_view fully_qualified_type_name) const; - // Return the set of enums configured within the type registry. - ABSL_DEPRECATED("Use GetRegisteredEnums to validate RegisterEnum calls.") - inline const absl::flat_hash_set& Enums() - const { - return enums_; - } - // Return the registered enums configured within the type registry in the // internal format that can be identified as int constants at plan time. const absl::flat_hash_map>& @@ -124,8 +114,6 @@ class CelTypeRegistry { // node_hash_set provides pointer-stability, which is required for the // strings backing CelType objects. mutable absl::node_hash_set types_ ABSL_GUARDED_BY(mutex_); - // Set of registered enums. - absl::flat_hash_set enums_; // Internal representation for enums. absl::flat_hash_map> resolveable_enums_; From e0a4586a402b77544c00a27f17a2b8f4544cb180 Mon Sep 17 00:00:00 2001 From: jcking Date: Tue, 16 May 2023 19:03:53 +0000 Subject: [PATCH 277/303] Fix bug related to the function registry and dangling references to descriptors PiperOrigin-RevId: 532532458 --- eval/public/cel_function_registry.h | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/eval/public/cel_function_registry.h b/eval/public/cel_function_registry.h index e8d484605..e1fb69074 100644 --- a/eval/public/cel_function_registry.h +++ b/eval/public/cel_function_registry.h @@ -44,7 +44,10 @@ class CelFunctionRegistry { // Function registration should be performed prior to // CelExpression creation. absl::Status Register(std::unique_ptr function) { - return Register(function->descriptor(), std::move(function)); + // We need to copy the descriptor, otherwise there is no guarantee that the + // lvalue reference to the descriptor is valid as function may be destroyed. + auto descriptor = function->descriptor(); + return Register(descriptor, std::move(function)); } absl::Status Register(const cel::FunctionDescriptor& descriptor, From 6f2619fe18c04b512fe32a8c6b48dfcf15a487a9 Mon Sep 17 00:00:00 2001 From: jcking Date: Tue, 16 May 2023 19:05:16 +0000 Subject: [PATCH 278/303] Fix ASan issue due to odd lifetimes regarding `absl::string_view` and `absl::Span` PiperOrigin-RevId: 532532876 --- base/types/dyn_type.cc | 3 ++- base/types/list_type.cc | 4 ++-- base/types/map_type.cc | 3 ++- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/base/types/dyn_type.cc b/base/types/dyn_type.cc index 24a8b812d..fab2ceb9e 100644 --- a/base/types/dyn_type.cc +++ b/base/types/dyn_type.cc @@ -23,7 +23,8 @@ CEL_INTERNAL_TYPE_IMPL(DynType); absl::Span DynType::aliases() const { // Currently google.protobuf.Value also resolves to dyn. - return absl::MakeConstSpan({absl::string_view("google.protobuf.Value")}); + static constexpr absl::string_view kAliases[] = {"google.protobuf.Value"}; + return absl::MakeConstSpan(kAliases); } } // namespace cel diff --git a/base/types/list_type.cc b/base/types/list_type.cc index cda767bf7..60a1d5e0a 100644 --- a/base/types/list_type.cc +++ b/base/types/list_type.cc @@ -29,10 +29,10 @@ namespace cel { CEL_INTERNAL_TYPE_IMPL(ListType); absl::Span ListType::aliases() const { + static constexpr absl::string_view kAliases[] = {"google.protobuf.ListValue"}; if (element()->kind() == Kind::kDyn) { // Currently google.protobuf.ListValue resolves to list. - return absl::MakeConstSpan( - {absl::string_view("google.protobuf.ListValue")}); + return absl::MakeConstSpan(kAliases); } return absl::Span(); } diff --git a/base/types/map_type.cc b/base/types/map_type.cc index 82bfac3d2..5fdf9e2a7 100644 --- a/base/types/map_type.cc +++ b/base/types/map_type.cc @@ -29,9 +29,10 @@ namespace cel { CEL_INTERNAL_TYPE_IMPL(MapType); absl::Span MapType::aliases() const { + static constexpr absl::string_view kAliases[] = {"google.protobuf.Struct"}; if (key()->kind() == Kind::kString && value()->kind() == Kind::kDyn) { // Currently google.protobuf.Struct resolves to map. - return absl::MakeConstSpan({absl::string_view("google.protobuf.Struct")}); + return absl::MakeConstSpan(kAliases); } return absl::Span(); } From 9dfb9feb03a7a66de810bae1fea3f16e11745840 Mon Sep 17 00:00:00 2001 From: jcking Date: Tue, 16 May 2023 19:22:26 +0000 Subject: [PATCH 279/303] Add more tests to `ProtoStructValue` PiperOrigin-RevId: 532538180 --- extensions/protobuf/struct_value_test.cc | 470 +++++++++++++++++++++++ extensions/protobuf/type_test.cc | 11 + 2 files changed, 481 insertions(+) diff --git a/extensions/protobuf/struct_value_test.cc b/extensions/protobuf/struct_value_test.cc index e83beded4..fe15dd5da 100644 --- a/extensions/protobuf/struct_value_test.cc +++ b/extensions/protobuf/struct_value_test.cc @@ -18,7 +18,9 @@ #include #include +#include "google/protobuf/any.pb.h" #include "google/protobuf/duration.pb.h" +#include "google/protobuf/struct.pb.h" #include "google/protobuf/timestamp.pb.h" #include "google/protobuf/wrappers.pb.h" #include "absl/functional/function_ref.h" @@ -1970,6 +1972,84 @@ TEST_P(ProtoStructValueTest, BoolMessageMapHasField) { std::make_pair(true, TestAllTypes::NestedMessage())); } +TEST_P(ProtoStructValueTest, BoolAnyMapHasField) { + TestMapHasField(memory_manager(), "map_bool_any", + &TestAllTypes::mutable_map_bool_any, + std::make_pair(true, google::protobuf::Any())); +} + +TEST_P(ProtoStructValueTest, BoolStructMapHasField) { + TestMapHasField(memory_manager(), "map_bool_struct", + &TestAllTypes::mutable_map_bool_struct, + std::make_pair(true, google::protobuf::Struct())); +} + +TEST_P(ProtoStructValueTest, BoolValueMapHasField) { + TestMapHasField(memory_manager(), "map_bool_value", + &TestAllTypes::mutable_map_bool_value, + std::make_pair(true, google::protobuf::Value())); +} + +TEST_P(ProtoStructValueTest, BoolListValueMapHasField) { + TestMapHasField(memory_manager(), "map_bool_list_value", + &TestAllTypes::mutable_map_bool_list_value, + std::make_pair(true, google::protobuf::ListValue())); +} + +TEST_P(ProtoStructValueTest, BoolInt64WrapperMapHasField) { + TestMapHasField(memory_manager(), "map_bool_int64_wrapper", + &TestAllTypes::mutable_map_bool_int64_wrapper, + std::make_pair(true, google::protobuf::Int64Value())); +} + +TEST_P(ProtoStructValueTest, BoolInt32WrapperMapHasField) { + TestMapHasField(memory_manager(), "map_bool_int32_wrapper", + &TestAllTypes::mutable_map_bool_int32_wrapper, + std::make_pair(true, google::protobuf::Int32Value())); +} + +TEST_P(ProtoStructValueTest, BoolDoubleWrapperMapHasField) { + TestMapHasField(memory_manager(), "map_bool_double_wrapper", + &TestAllTypes::mutable_map_bool_double_wrapper, + std::make_pair(true, google::protobuf::DoubleValue())); +} + +TEST_P(ProtoStructValueTest, BoolFloatWrapperMapHasField) { + TestMapHasField(memory_manager(), "map_bool_float_wrapper", + &TestAllTypes::mutable_map_bool_float_wrapper, + std::make_pair(true, google::protobuf::FloatValue())); +} + +TEST_P(ProtoStructValueTest, BoolUInt64WrapperMapHasField) { + TestMapHasField(memory_manager(), "map_bool_uint64_wrapper", + &TestAllTypes::mutable_map_bool_uint64_wrapper, + std::make_pair(true, google::protobuf::UInt64Value())); +} + +TEST_P(ProtoStructValueTest, BoolUInt32WrapperMapHasField) { + TestMapHasField(memory_manager(), "map_bool_uint32_wrapper", + &TestAllTypes::mutable_map_bool_uint32_wrapper, + std::make_pair(true, google::protobuf::UInt32Value())); +} + +TEST_P(ProtoStructValueTest, BoolStringWrapperMapHasField) { + TestMapHasField(memory_manager(), "map_bool_string_wrapper", + &TestAllTypes::mutable_map_bool_string_wrapper, + std::make_pair(true, google::protobuf::StringValue())); +} + +TEST_P(ProtoStructValueTest, BoolBoolWrapperMapHasField) { + TestMapHasField(memory_manager(), "map_bool_bool_wrapper", + &TestAllTypes::mutable_map_bool_bool_wrapper, + std::make_pair(true, google::protobuf::BoolValue())); +} + +TEST_P(ProtoStructValueTest, BoolBytesWrapperMapHasField) { + TestMapHasField(memory_manager(), "map_bool_bytes_wrapper", + &TestAllTypes::mutable_map_bool_bytes_wrapper, + std::make_pair(true, google::protobuf::BytesValue())); +} + TEST_P(ProtoStructValueTest, Int32NullValueMapHasField) { TestMapHasField(memory_manager(), "map_int32_null_value", &TestAllTypes::mutable_map_int32_null_value, @@ -2052,6 +2132,84 @@ TEST_P(ProtoStructValueTest, Int32MessageMapHasField) { std::make_pair(1, TestAllTypes::NestedMessage())); } +TEST_P(ProtoStructValueTest, Int32AnyMapHasField) { + TestMapHasField(memory_manager(), "map_int32_any", + &TestAllTypes::mutable_map_int32_any, + std::make_pair(1, google::protobuf::Any())); +} + +TEST_P(ProtoStructValueTest, Int32StructMapHasField) { + TestMapHasField(memory_manager(), "map_int32_struct", + &TestAllTypes::mutable_map_int32_struct, + std::make_pair(1, google::protobuf::Struct())); +} + +TEST_P(ProtoStructValueTest, Int32ValueMapHasField) { + TestMapHasField(memory_manager(), "map_int32_value", + &TestAllTypes::mutable_map_int32_value, + std::make_pair(1, google::protobuf::Value())); +} + +TEST_P(ProtoStructValueTest, Int32ListValueMapHasField) { + TestMapHasField(memory_manager(), "map_int32_list_value", + &TestAllTypes::mutable_map_int32_list_value, + std::make_pair(1, google::protobuf::ListValue())); +} + +TEST_P(ProtoStructValueTest, Int32Int64WrapperMapHasField) { + TestMapHasField(memory_manager(), "map_int32_int64_wrapper", + &TestAllTypes::mutable_map_int32_int64_wrapper, + std::make_pair(1, google::protobuf::Int64Value())); +} + +TEST_P(ProtoStructValueTest, Int32Int32WrapperMapHasField) { + TestMapHasField(memory_manager(), "map_int32_int32_wrapper", + &TestAllTypes::mutable_map_int32_int32_wrapper, + std::make_pair(1, google::protobuf::Int32Value())); +} + +TEST_P(ProtoStructValueTest, Int32DoubleWrapperMapHasField) { + TestMapHasField(memory_manager(), "map_int32_double_wrapper", + &TestAllTypes::mutable_map_int32_double_wrapper, + std::make_pair(1, google::protobuf::DoubleValue())); +} + +TEST_P(ProtoStructValueTest, Int32FloatWrapperMapHasField) { + TestMapHasField(memory_manager(), "map_int32_float_wrapper", + &TestAllTypes::mutable_map_int32_float_wrapper, + std::make_pair(1, google::protobuf::FloatValue())); +} + +TEST_P(ProtoStructValueTest, Int32UInt64WrapperMapHasField) { + TestMapHasField(memory_manager(), "map_int32_uint64_wrapper", + &TestAllTypes::mutable_map_int32_uint64_wrapper, + std::make_pair(1, google::protobuf::UInt64Value())); +} + +TEST_P(ProtoStructValueTest, Int32UInt32WrapperMapHasField) { + TestMapHasField(memory_manager(), "map_int32_uint32_wrapper", + &TestAllTypes::mutable_map_int32_uint32_wrapper, + std::make_pair(1, google::protobuf::UInt32Value())); +} + +TEST_P(ProtoStructValueTest, Int32StringWrapperMapHasField) { + TestMapHasField(memory_manager(), "map_int32_string_wrapper", + &TestAllTypes::mutable_map_int32_string_wrapper, + std::make_pair(1, google::protobuf::StringValue())); +} + +TEST_P(ProtoStructValueTest, Int32BoolWrapperMapHasField) { + TestMapHasField(memory_manager(), "map_int32_bool_wrapper", + &TestAllTypes::mutable_map_int32_bool_wrapper, + std::make_pair(1, google::protobuf::BoolValue())); +} + +TEST_P(ProtoStructValueTest, Int32BytesWrapperMapHasField) { + TestMapHasField(memory_manager(), "map_int32_bytes_wrapper", + &TestAllTypes::mutable_map_int32_bytes_wrapper, + std::make_pair(1, google::protobuf::BytesValue())); +} + TEST_P(ProtoStructValueTest, Int64NullValueMapHasField) { TestMapHasField(memory_manager(), "map_int64_null_value", &TestAllTypes::mutable_map_int64_null_value, @@ -2134,6 +2292,84 @@ TEST_P(ProtoStructValueTest, Int64MessageMapHasField) { std::make_pair(1, TestAllTypes::NestedMessage())); } +TEST_P(ProtoStructValueTest, Int64AnyMapHasField) { + TestMapHasField(memory_manager(), "map_int64_any", + &TestAllTypes::mutable_map_int64_any, + std::make_pair(1, google::protobuf::Any())); +} + +TEST_P(ProtoStructValueTest, Int64StructMapHasField) { + TestMapHasField(memory_manager(), "map_int64_struct", + &TestAllTypes::mutable_map_int64_struct, + std::make_pair(1, google::protobuf::Struct())); +} + +TEST_P(ProtoStructValueTest, Int64ValueMapHasField) { + TestMapHasField(memory_manager(), "map_int64_value", + &TestAllTypes::mutable_map_int64_value, + std::make_pair(1, google::protobuf::Value())); +} + +TEST_P(ProtoStructValueTest, Int64ListValueMapHasField) { + TestMapHasField(memory_manager(), "map_int64_list_value", + &TestAllTypes::mutable_map_int64_list_value, + std::make_pair(1, google::protobuf::ListValue())); +} + +TEST_P(ProtoStructValueTest, Int64Int64WrapperMapHasField) { + TestMapHasField(memory_manager(), "map_int64_int64_wrapper", + &TestAllTypes::mutable_map_int64_int64_wrapper, + std::make_pair(1, google::protobuf::Int64Value())); +} + +TEST_P(ProtoStructValueTest, Int64Int32WrapperMapHasField) { + TestMapHasField(memory_manager(), "map_int64_int32_wrapper", + &TestAllTypes::mutable_map_int64_int32_wrapper, + std::make_pair(1, google::protobuf::Int32Value())); +} + +TEST_P(ProtoStructValueTest, Int64DoubleWrapperMapHasField) { + TestMapHasField(memory_manager(), "map_int64_double_wrapper", + &TestAllTypes::mutable_map_int64_double_wrapper, + std::make_pair(1, google::protobuf::DoubleValue())); +} + +TEST_P(ProtoStructValueTest, Int64FloatWrapperMapHasField) { + TestMapHasField(memory_manager(), "map_int64_float_wrapper", + &TestAllTypes::mutable_map_int64_float_wrapper, + std::make_pair(1, google::protobuf::FloatValue())); +} + +TEST_P(ProtoStructValueTest, Int64UInt64WrapperMapHasField) { + TestMapHasField(memory_manager(), "map_int64_uint64_wrapper", + &TestAllTypes::mutable_map_int64_uint64_wrapper, + std::make_pair(1, google::protobuf::UInt64Value())); +} + +TEST_P(ProtoStructValueTest, Int64UInt32WrapperMapHasField) { + TestMapHasField(memory_manager(), "map_int64_uint32_wrapper", + &TestAllTypes::mutable_map_int64_uint32_wrapper, + std::make_pair(1, google::protobuf::UInt32Value())); +} + +TEST_P(ProtoStructValueTest, Int64StringWrapperMapHasField) { + TestMapHasField(memory_manager(), "map_int64_string_wrapper", + &TestAllTypes::mutable_map_int64_string_wrapper, + std::make_pair(1, google::protobuf::StringValue())); +} + +TEST_P(ProtoStructValueTest, Int64BoolWrapperMapHasField) { + TestMapHasField(memory_manager(), "map_int64_bool_wrapper", + &TestAllTypes::mutable_map_int64_bool_wrapper, + std::make_pair(1, google::protobuf::BoolValue())); +} + +TEST_P(ProtoStructValueTest, Int64BytesWrapperMapHasField) { + TestMapHasField(memory_manager(), "map_int64_bytes_wrapper", + &TestAllTypes::mutable_map_int64_bytes_wrapper, + std::make_pair(1, google::protobuf::BytesValue())); +} + TEST_P(ProtoStructValueTest, Uint32NullValueMapHasField) { TestMapHasField(memory_manager(), "map_uint32_null_value", &TestAllTypes::mutable_map_uint32_null_value, @@ -2218,6 +2454,84 @@ TEST_P(ProtoStructValueTest, Uint32MessageMapHasField) { std::make_pair(1u, TestAllTypes::NestedMessage())); } +TEST_P(ProtoStructValueTest, Uint32AnyMapHasField) { + TestMapHasField(memory_manager(), "map_uint32_any", + &TestAllTypes::mutable_map_uint32_any, + std::make_pair(1, google::protobuf::Any())); +} + +TEST_P(ProtoStructValueTest, Uint32StructMapHasField) { + TestMapHasField(memory_manager(), "map_uint32_struct", + &TestAllTypes::mutable_map_uint32_struct, + std::make_pair(1, google::protobuf::Struct())); +} + +TEST_P(ProtoStructValueTest, Uint32ValueMapHasField) { + TestMapHasField(memory_manager(), "map_uint32_value", + &TestAllTypes::mutable_map_uint32_value, + std::make_pair(1, google::protobuf::Value())); +} + +TEST_P(ProtoStructValueTest, Uint32ListValueMapHasField) { + TestMapHasField(memory_manager(), "map_uint32_list_value", + &TestAllTypes::mutable_map_uint32_list_value, + std::make_pair(1, google::protobuf::ListValue())); +} + +TEST_P(ProtoStructValueTest, Uint32Int64WrapperMapHasField) { + TestMapHasField(memory_manager(), "map_uint32_int64_wrapper", + &TestAllTypes::mutable_map_uint32_int64_wrapper, + std::make_pair(1, google::protobuf::Int64Value())); +} + +TEST_P(ProtoStructValueTest, Uint32Int32WrapperMapHasField) { + TestMapHasField(memory_manager(), "map_uint32_int32_wrapper", + &TestAllTypes::mutable_map_uint32_int32_wrapper, + std::make_pair(1, google::protobuf::Int32Value())); +} + +TEST_P(ProtoStructValueTest, Uint32DoubleWrapperMapHasField) { + TestMapHasField(memory_manager(), "map_uint32_double_wrapper", + &TestAllTypes::mutable_map_uint32_double_wrapper, + std::make_pair(1, google::protobuf::DoubleValue())); +} + +TEST_P(ProtoStructValueTest, Uint32FloatWrapperMapHasField) { + TestMapHasField(memory_manager(), "map_uint32_float_wrapper", + &TestAllTypes::mutable_map_uint32_float_wrapper, + std::make_pair(1, google::protobuf::FloatValue())); +} + +TEST_P(ProtoStructValueTest, Uint32UInt64WrapperMapHasField) { + TestMapHasField(memory_manager(), "map_uint32_uint64_wrapper", + &TestAllTypes::mutable_map_uint32_uint64_wrapper, + std::make_pair(1, google::protobuf::UInt64Value())); +} + +TEST_P(ProtoStructValueTest, Uint32UInt32WrapperMapHasField) { + TestMapHasField(memory_manager(), "map_uint32_uint32_wrapper", + &TestAllTypes::mutable_map_uint32_uint32_wrapper, + std::make_pair(1, google::protobuf::UInt32Value())); +} + +TEST_P(ProtoStructValueTest, Uint32StringWrapperMapHasField) { + TestMapHasField(memory_manager(), "map_uint32_string_wrapper", + &TestAllTypes::mutable_map_uint32_string_wrapper, + std::make_pair(1, google::protobuf::StringValue())); +} + +TEST_P(ProtoStructValueTest, Uint32BoolWrapperMapHasField) { + TestMapHasField(memory_manager(), "map_uint32_bool_wrapper", + &TestAllTypes::mutable_map_uint32_bool_wrapper, + std::make_pair(1, google::protobuf::BoolValue())); +} + +TEST_P(ProtoStructValueTest, Uint32BytesWrapperMapHasField) { + TestMapHasField(memory_manager(), "map_uint32_bytes_wrapper", + &TestAllTypes::mutable_map_uint32_bytes_wrapper, + std::make_pair(1, google::protobuf::BytesValue())); +} + TEST_P(ProtoStructValueTest, Uint64NullValueMapHasField) { TestMapHasField(memory_manager(), "map_uint64_null_value", &TestAllTypes::mutable_map_uint64_null_value, @@ -2302,6 +2616,84 @@ TEST_P(ProtoStructValueTest, Uint64MessageMapHasField) { std::make_pair(1u, TestAllTypes::NestedMessage())); } +TEST_P(ProtoStructValueTest, Uint64AnyMapHasField) { + TestMapHasField(memory_manager(), "map_uint64_any", + &TestAllTypes::mutable_map_uint64_any, + std::make_pair(1, google::protobuf::Any())); +} + +TEST_P(ProtoStructValueTest, Uint64StructMapHasField) { + TestMapHasField(memory_manager(), "map_uint64_struct", + &TestAllTypes::mutable_map_uint64_struct, + std::make_pair(1, google::protobuf::Struct())); +} + +TEST_P(ProtoStructValueTest, Uint64ValueMapHasField) { + TestMapHasField(memory_manager(), "map_uint64_value", + &TestAllTypes::mutable_map_uint64_value, + std::make_pair(1, google::protobuf::Value())); +} + +TEST_P(ProtoStructValueTest, Uint64ListValueMapHasField) { + TestMapHasField(memory_manager(), "map_uint64_list_value", + &TestAllTypes::mutable_map_uint64_list_value, + std::make_pair(1, google::protobuf::ListValue())); +} + +TEST_P(ProtoStructValueTest, Uint64Int64WrapperMapHasField) { + TestMapHasField(memory_manager(), "map_uint64_int64_wrapper", + &TestAllTypes::mutable_map_uint64_int64_wrapper, + std::make_pair(1, google::protobuf::Int64Value())); +} + +TEST_P(ProtoStructValueTest, Uint64Int32WrapperMapHasField) { + TestMapHasField(memory_manager(), "map_uint64_int32_wrapper", + &TestAllTypes::mutable_map_uint64_int32_wrapper, + std::make_pair(1, google::protobuf::Int32Value())); +} + +TEST_P(ProtoStructValueTest, Uint64DoubleWrapperMapHasField) { + TestMapHasField(memory_manager(), "map_uint64_double_wrapper", + &TestAllTypes::mutable_map_uint64_double_wrapper, + std::make_pair(1, google::protobuf::DoubleValue())); +} + +TEST_P(ProtoStructValueTest, Uint64FloatWrapperMapHasField) { + TestMapHasField(memory_manager(), "map_uint64_float_wrapper", + &TestAllTypes::mutable_map_uint64_float_wrapper, + std::make_pair(1, google::protobuf::FloatValue())); +} + +TEST_P(ProtoStructValueTest, Uint64UInt64WrapperMapHasField) { + TestMapHasField(memory_manager(), "map_uint64_uint64_wrapper", + &TestAllTypes::mutable_map_uint64_uint64_wrapper, + std::make_pair(1, google::protobuf::UInt64Value())); +} + +TEST_P(ProtoStructValueTest, Uint64UInt32WrapperMapHasField) { + TestMapHasField(memory_manager(), "map_uint64_uint32_wrapper", + &TestAllTypes::mutable_map_uint64_uint32_wrapper, + std::make_pair(1, google::protobuf::UInt32Value())); +} + +TEST_P(ProtoStructValueTest, Uint64StringWrapperMapHasField) { + TestMapHasField(memory_manager(), "map_uint64_string_wrapper", + &TestAllTypes::mutable_map_uint64_string_wrapper, + std::make_pair(1, google::protobuf::StringValue())); +} + +TEST_P(ProtoStructValueTest, Uint64BoolWrapperMapHasField) { + TestMapHasField(memory_manager(), "map_uint64_bool_wrapper", + &TestAllTypes::mutable_map_uint64_bool_wrapper, + std::make_pair(1, google::protobuf::BoolValue())); +} + +TEST_P(ProtoStructValueTest, Uint64BytesWrapperMapHasField) { + TestMapHasField(memory_manager(), "map_uint64_bytes_wrapper", + &TestAllTypes::mutable_map_uint64_bytes_wrapper, + std::make_pair(1, google::protobuf::BytesValue())); +} + TEST_P(ProtoStructValueTest, StringNullValueMapHasField) { TestMapHasField(memory_manager(), "map_string_null_value", &TestAllTypes::mutable_map_string_null_value, @@ -2386,6 +2778,84 @@ TEST_P(ProtoStructValueTest, StringMessageMapHasField) { std::make_pair("foo", TestAllTypes::NestedMessage())); } +TEST_P(ProtoStructValueTest, StringAnyMapHasField) { + TestMapHasField(memory_manager(), "map_string_any", + &TestAllTypes::mutable_map_string_any, + std::make_pair("foo", google::protobuf::Any())); +} + +TEST_P(ProtoStructValueTest, StringStructMapHasField) { + TestMapHasField(memory_manager(), "map_string_struct", + &TestAllTypes::mutable_map_string_struct, + std::make_pair("foo", google::protobuf::Struct())); +} + +TEST_P(ProtoStructValueTest, StringValueMapHasField) { + TestMapHasField(memory_manager(), "map_string_value", + &TestAllTypes::mutable_map_string_value, + std::make_pair("foo", google::protobuf::Value())); +} + +TEST_P(ProtoStructValueTest, StringListValueMapHasField) { + TestMapHasField(memory_manager(), "map_string_list_value", + &TestAllTypes::mutable_map_string_list_value, + std::make_pair("foo", google::protobuf::ListValue())); +} + +TEST_P(ProtoStructValueTest, StringInt64WrapperMapHasField) { + TestMapHasField(memory_manager(), "map_string_int64_wrapper", + &TestAllTypes::mutable_map_string_int64_wrapper, + std::make_pair("foo", google::protobuf::Int64Value())); +} + +TEST_P(ProtoStructValueTest, StringInt32WrapperMapHasField) { + TestMapHasField(memory_manager(), "map_string_int32_wrapper", + &TestAllTypes::mutable_map_string_int32_wrapper, + std::make_pair("foo", google::protobuf::Int32Value())); +} + +TEST_P(ProtoStructValueTest, StringDoubleWrapperMapHasField) { + TestMapHasField(memory_manager(), "map_string_double_wrapper", + &TestAllTypes::mutable_map_string_double_wrapper, + std::make_pair("foo", google::protobuf::DoubleValue())); +} + +TEST_P(ProtoStructValueTest, StringFloatWrapperMapHasField) { + TestMapHasField(memory_manager(), "map_string_float_wrapper", + &TestAllTypes::mutable_map_string_float_wrapper, + std::make_pair("foo", google::protobuf::FloatValue())); +} + +TEST_P(ProtoStructValueTest, StringUInt64WrapperMapHasField) { + TestMapHasField(memory_manager(), "map_string_uint64_wrapper", + &TestAllTypes::mutable_map_string_uint64_wrapper, + std::make_pair("foo", google::protobuf::UInt64Value())); +} + +TEST_P(ProtoStructValueTest, StringUInt32WrapperMapHasField) { + TestMapHasField(memory_manager(), "map_string_uint32_wrapper", + &TestAllTypes::mutable_map_string_uint32_wrapper, + std::make_pair("foo", google::protobuf::UInt32Value())); +} + +TEST_P(ProtoStructValueTest, StringStringWrapperMapHasField) { + TestMapHasField(memory_manager(), "map_string_string_wrapper", + &TestAllTypes::mutable_map_string_string_wrapper, + std::make_pair("foo", google::protobuf::StringValue())); +} + +TEST_P(ProtoStructValueTest, StringBoolWrapperMapHasField) { + TestMapHasField(memory_manager(), "map_string_bool_wrapper", + &TestAllTypes::mutable_map_string_bool_wrapper, + std::make_pair("foo", google::protobuf::BoolValue())); +} + +TEST_P(ProtoStructValueTest, StringBytesWrapperMapHasField) { + TestMapHasField(memory_manager(), "map_string_bytes_wrapper", + &TestAllTypes::mutable_map_string_bytes_wrapper, + std::make_pair("foo", google::protobuf::BytesValue())); +} + TEST_P(ProtoStructValueTest, BoolNullValueMapGetField) { TestMapGetField(memory_manager(), "map_bool_null_value", "{false: null, true: null}", diff --git a/extensions/protobuf/type_test.cc b/extensions/protobuf/type_test.cc index 067d2f0e1..80338e839 100644 --- a/extensions/protobuf/type_test.cc +++ b/extensions/protobuf/type_test.cc @@ -14,7 +14,9 @@ #include "extensions/protobuf/type.h" +#include "google/protobuf/api.pb.h" #include "google/protobuf/wrappers.pb.h" +#include "absl/status/status.h" #include "base/internal/memory_manager_testing.h" #include "base/testing/type_matchers.h" #include "base/type_factory.h" @@ -26,6 +28,7 @@ namespace cel::extensions { namespace { using ::cel_testing::TypeIs; +using testing::status::CanonicalStatusIs; using cel::internal::IsOkAndHolds; using ProtoTypeTest = ProtoTest<>; @@ -87,6 +90,14 @@ TEST_P(ProtoTypeTest, DynamicWrapperTypes) { IsOkAndHolds(TypeIs())); } +TEST_P(ProtoTypeTest, ResolveNotFound) { + TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + EXPECT_THAT( + ProtoType::Resolve(type_manager, *google::protobuf::Api::descriptor()), + CanonicalStatusIs(absl::StatusCode::kNotFound)); +} + INSTANTIATE_TEST_SUITE_P(ProtoTypeTest, ProtoTypeTest, cel::base_internal::MemoryManagerTestModeAll(), cel::base_internal::MemoryManagerTestModeTupleName); From 79e4a22743f1312ce38717a5e93c07e985a97847 Mon Sep 17 00:00:00 2001 From: jdtatum Date: Tue, 16 May 2023 19:26:18 +0000 Subject: [PATCH 280/303] Update cel::TypeProvider references to be const. PiperOrigin-RevId: 532539270 --- base/type_manager.h | 6 +++--- base/type_provider.h | 3 +++ base/value_factory.h | 4 +++- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/base/type_manager.h b/base/type_manager.h index ccefeab42..e6e975f33 100644 --- a/base/type_manager.h +++ b/base/type_manager.h @@ -27,7 +27,7 @@ namespace cel { -// TypeManager is a union of the TypeFactory and TypeRegistry, allowing for both +// TypeManager is a union of the TypeFactory and TypeProvider, allowing for both // the instantiation of type implementations, loading of type implementations, // and registering type implementations. // @@ -44,7 +44,7 @@ class TypeManager final { TypeFactory& type_factory() const { return type_factory_; } - TypeProvider& type_provider() const { return type_provider_; } + const TypeProvider& type_provider() const { return type_provider_; } absl::StatusOr>> ResolveType( absl::string_view name); @@ -56,7 +56,7 @@ class TypeManager final { Handle&& type); TypeFactory& type_factory_; - TypeProvider& type_provider_; + const TypeProvider& type_provider_; absl::Mutex mutex_; // std::string as the key because we also cache types which do not exist. diff --git a/base/type_provider.h b/base/type_provider.h index 3c416132a..c99eab985 100644 --- a/base/type_provider.h +++ b/base/type_provider.h @@ -35,6 +35,9 @@ class TypeFactory; // of the registered providers. If the type can't be resolved, the operation // will result in an error. // +// Type provider implementations must be effectively immutable and threadsafe. +// The type registry uses this property to aggressively cache results. +// // Note: This API is not finalized. Consult the CEL team before introducing new // implementations. class TypeProvider { diff --git a/base/value_factory.h b/base/value_factory.h index 4538aa717..a8dbbdf9e 100644 --- a/base/value_factory.h +++ b/base/value_factory.h @@ -99,7 +99,9 @@ class ValueFactory final { TypeFactory& type_factory() const { return type_manager().type_factory(); } - TypeProvider& type_provider() const { return type_manager().type_provider(); } + const TypeProvider& type_provider() const { + return type_manager().type_provider(); + } TypeManager& type_manager() const { return type_manager_; } From cc7daec753155fc81b9d57aaafde2935692ef5c8 Mon Sep 17 00:00:00 2001 From: jdtatum Date: Tue, 16 May 2023 20:35:02 +0000 Subject: [PATCH 281/303] Migrate regex precompilation to program optimizer interface. PiperOrigin-RevId: 532559575 --- eval/compiler/BUILD | 40 +++- eval/compiler/flat_expr_builder.cc | 114 +-------- eval/compiler/flat_expr_builder.h | 5 - eval/compiler/flat_expr_builder_extensions.cc | 26 +++ eval/compiler/flat_expr_builder_extensions.h | 6 + .../flat_expr_builder_extensions_test.cc | 60 ++++- .../regex_precompilation_optimization.cc | 180 +++++++++++++++ .../regex_precompilation_optimization.h | 29 +++ .../regex_precompilation_optimization_test.cc | 217 ++++++++++++++++++ eval/eval/evaluator_core.h | 2 + eval/eval/regex_match_step.cc | 14 +- eval/public/BUILD | 1 + .../portable_cel_expr_builder_factory.cc | 7 +- 13 files changed, 567 insertions(+), 134 deletions(-) create mode 100644 eval/compiler/regex_precompilation_optimization.cc create mode 100644 eval/compiler/regex_precompilation_optimization.h create mode 100644 eval/compiler/regex_precompilation_optimization_test.cc diff --git a/eval/compiler/BUILD b/eval/compiler/BUILD index 74af97aad..b680b68c7 100644 --- a/eval/compiler/BUILD +++ b/eval/compiler/BUILD @@ -71,7 +71,6 @@ cc_library( "//eval/eval:ident_step", "//eval/eval:jump_step", "//eval/eval:logic_step", - "//eval/eval:regex_match_step", "//eval/eval:select_step", "//eval/eval:shadowable_value_step", "//eval/eval:ternary_step", @@ -362,6 +361,45 @@ cc_test( ], ) +cc_library( + name = "regex_precompilation_optimization", + srcs = ["regex_precompilation_optimization.cc"], + hdrs = ["regex_precompilation_optimization.h"], + deps = [ + ":flat_expr_builder_extensions", + "//base:ast_internal", + "//base:builtins", + "//base:value", + "//base/internal:ast_impl", + "//eval/eval:compiler_constant_step", + "//eval/eval:regex_match_step", + "//internal:casts", + "//internal:rtti", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:optional", + ], +) + +cc_test( + name = "regex_precompilation_optimization_test", + srcs = ["regex_precompilation_optimization_test.cc"], + deps = [ + ":flat_expr_builder", + ":flat_expr_builder_extensions", + ":regex_precompilation_optimization", + "//base:ast_internal", + "//base/internal:ast_impl", + "//eval/eval:evaluator_core", + "//eval/public:builtin_func_registrar", + "//eval/public:cel_options", + "//internal:testing", + "//parser", + "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto", + "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + package_group( name = "native_api_users", packages = [ diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index 2ef8c42b3..490193d94 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -41,7 +41,6 @@ #include "base/ast.h" #include "base/ast_internal.h" #include "base/internal/ast_impl.h" -#include "base/values/string_value.h" #include "eval/compiler/constant_folding.h" #include "eval/compiler/flat_expr_builder_extensions.h" #include "eval/compiler/resolver.h" @@ -56,7 +55,6 @@ #include "eval/eval/ident_step.h" #include "eval/eval/jump_step.h" #include "eval/eval/logic_step.h" -#include "eval/eval/regex_match_step.h" #include "eval/eval/select_step.h" #include "eval/eval/shadowable_value_step.h" #include "eval/eval/ternary_step.h" @@ -75,7 +73,6 @@ namespace google::api::expr::runtime { namespace { using ::cel::Handle; -using ::cel::StringValue; using ::cel::Value; using ::cel::ast::Ast; using ::cel::ast::internal::AstImpl; @@ -92,67 +89,9 @@ using Comprehension = ::google::api::expr::v1alpha1::Expr::Comprehension; constexpr int64_t kExprIdNotFromAst = -1; -template -bool IsFunctionOverload( - const ExprT& expr, absl::string_view function, absl::string_view overload, - size_t arity, - const absl::flat_hash_map* - reference_map) { - if (reference_map == nullptr || !expr.has_call_expr()) { - return false; - } - const auto& call_expr = expr.call_expr(); - if (call_expr.function() != function) { - return false; - } - if (call_expr.args().size() + (call_expr.has_target() ? 1 : 0) != arity) { - return false; - } - auto reference = reference_map->find(expr.id()); - if (reference != reference_map->end() && - reference->second.overload_id().size() == 1 && - reference->second.overload_id().front() == overload) { - return true; - } - return false; -} - // Forward declare to resolve circular dependency for short_circuiting visitors. class FlatExprVisitor; -// Abstraction for deduplicating regular expressions over the course of a single -// create expression call. Should not be used during evaluation. Uses -// std::shared_ptr and std::weak_ptr. -class RegexProgramBuilder final { - public: - explicit RegexProgramBuilder(int max_program_size) - : max_program_size_(max_program_size) {} - - absl::StatusOr> BuildRegexProgram( - std::string pattern) { - auto existing = programs_.find(pattern); - if (existing != programs_.end()) { - if (auto program = existing->second.lock(); program) { - return program; - } - programs_.erase(existing); - } - auto program = std::make_shared(pattern); - if (max_program_size_ > 0 && program->ProgramSize() > max_program_size_) { - return absl::InvalidArgumentError("exceeded RE2 max program size"); - } - if (!program->ok()) { - return absl::InvalidArgumentError("invalid_argument"); - } - programs_.insert({std::move(pattern), program}); - return program; - } - - private: - const int max_program_size_; - absl::flat_hash_map> programs_; -}; - // A convenience wrapper for offset-calculating logic. class Jump { public: @@ -273,7 +212,6 @@ class FlatExprVisitor : public cel::ast::internal::AstVisitor { const cel::RuntimeOptions& options, const absl::flat_hash_map>& constant_idents, bool enable_comprehension_vulnerability_check, - bool enable_regex_precompilation, absl::Span> program_optimizers, const absl::flat_hash_map* reference_map, @@ -290,10 +228,8 @@ class FlatExprVisitor : public cel::ast::internal::AstVisitor { constant_idents_(constant_idents), enable_comprehension_vulnerability_check_( enable_comprehension_vulnerability_check), - enable_regex_precompilation_(enable_regex_precompilation), program_optimizers_(program_optimizers), builder_warnings_(warnings), - regex_program_builder_(options_.regex_max_program_size), reference_map_(reference_map), program_tree_(program_tree), extension_context_(extension_context) {} @@ -550,20 +486,6 @@ class FlatExprVisitor : public cel::ast::internal::AstVisitor { size_t num_args = call_expr->args().size() + (receiver_style ? 1 : 0); auto arguments_matcher = ArgumentsMatcher(num_args); - // Check to see if this is regular expression matching and the pattern is a - // constant. - if (options_.enable_regex && enable_regex_precompilation_ && - IsOptimizeableMatchesCall(*expr, *call_expr)) { - auto program = regex_program_builder_.BuildRegexProgram( - GetConstantString(call_expr->args().back())); - if (!program.ok()) { - SetProgressStatusError(program.status()); - return; - } - AddStep(CreateRegexMatchStep(std::move(program).value(), expr->id())); - return; - } - // Check to see if this is a special case of add that should really be // treated as a list append if (options_.enable_comprehension_list_append && @@ -798,37 +720,6 @@ class FlatExprVisitor : public cel::ast::internal::AstVisitor { } private: - bool IsConstantString(const cel::ast::internal::Expr& expr) const { - if (expr.has_const_expr() && expr.const_expr().has_string_value()) { - return true; - } - if (!expr.has_ident_expr()) { - return false; - } - auto const_value = constant_idents_.find(expr.ident_expr().name()); - return const_value != constant_idents_.end() && - const_value->second->Is(); - } - - std::string GetConstantString(const cel::ast::internal::Expr& expr) const { - ABSL_ASSERT(IsConstantString(expr)); - if (expr.has_const_expr()) { - return expr.const_expr().string_value(); - } - return constant_idents_.find(expr.ident_expr().name()) - ->second.As() - ->ToString(); - } - - bool IsOptimizeableMatchesCall( - const cel::ast::internal::Expr& expr, - const cel::ast::internal::Call& call_expr) const { - return IsFunctionOverload(expr, - google::api::expr::runtime::builtin::kRegexMatch, - "matches_string", 2, reference_map_) && - IsConstantString(call_expr.args().back()); - } - const google::api::expr::runtime::Resolver& resolver_; google::api::expr::runtime::ExecutionPath* execution_path_; absl::Status progress_status_; @@ -860,12 +751,10 @@ class FlatExprVisitor : public cel::ast::internal::AstVisitor { std::stack comprehension_stack_; bool enable_comprehension_vulnerability_check_; - bool enable_regex_precompilation_; absl::Span> program_optimizers_; google::api::expr::runtime::BuilderWarnings* builder_warnings_; - RegexProgramBuilder regex_program_builder_; const absl::flat_hash_map* const reference_map_; @@ -1331,8 +1220,7 @@ FlatExprBuilder::CreateExpressionImpl( optimizer_factory(extension_context, ast_impl)); } FlatExprVisitor visitor(resolver, options_, constant_idents, - enable_comprehension_vulnerability_check_, - enable_regex_precompilation_, optimizers, + enable_comprehension_vulnerability_check_, optimizers, &ast_impl.reference_map(), &execution_path, &warnings_builder, program_tree, extension_context); diff --git a/eval/compiler/flat_expr_builder.h b/eval/compiler/flat_expr_builder.h index 6bb808fec..0324ffcbf 100644 --- a/eval/compiler/flat_expr_builder.h +++ b/eval/compiler/flat_expr_builder.h @@ -71,10 +71,6 @@ class FlatExprBuilder : public CelExpressionBuilder { program_optimizers_.push_back(std::move(optimizer)); } - void set_enable_regex_precompilation(bool enable) { - enable_regex_precompilation_ = enable; - } - absl::StatusOr> CreateExpression( const google::api::expr::v1alpha1::Expr* expr, const google::api::expr::v1alpha1::SourceInfo* source_info) const override; @@ -105,7 +101,6 @@ class FlatExprBuilder : public CelExpressionBuilder { std::vector> ast_transforms_; std::vector program_optimizers_; - bool enable_regex_precompilation_ = false; bool enable_comprehension_vulnerability_check_ = false; bool constant_folding_ = false; google::protobuf::Arena* constant_arena_ = nullptr; diff --git a/eval/compiler/flat_expr_builder_extensions.cc b/eval/compiler/flat_expr_builder_extensions.cc index 1ddde6902..3e1c69ac3 100644 --- a/eval/compiler/flat_expr_builder_extensions.cc +++ b/eval/compiler/flat_expr_builder_extensions.cc @@ -41,6 +41,32 @@ ExecutionPathView PlannerContext::GetSubplan( .subspan(info.range_start, info.range_len); } +absl::StatusOr PlannerContext::ExtractSubplan( + const cel::ast::internal::Expr& node) { + auto iter = program_tree_.find(&node); + if (iter == program_tree_.end()) { + return absl::InternalError("attempted to rewrite unknown program step"); + } + + ProgramInfo& info = iter->second; + + if (info.range_len == -1) { + // Initial planning for this node hasn't finished. + return absl::InternalError( + "attempted to rewrite program step before completion."); + } + + ExecutionPath out; + out.reserve(info.range_len); + + out.insert(out.begin(), + std::move_iterator(execution_path_.begin() + info.range_start), + std::move_iterator(execution_path_.begin() + info.range_start + + info.range_len)); + + return out; +} + absl::Status PlannerContext::ReplaceSubplan( const cel::ast::internal::Expr& node, ExecutionPath path) { auto iter = program_tree_.find(&node); diff --git a/eval/compiler/flat_expr_builder_extensions.h b/eval/compiler/flat_expr_builder_extensions.h index edd134eab..af2f4862b 100644 --- a/eval/compiler/flat_expr_builder_extensions.h +++ b/eval/compiler/flat_expr_builder_extensions.h @@ -67,6 +67,12 @@ class PlannerContext { // Note: this is invalidated after a sibling or parent is updated. ExecutionPathView GetSubplan(const cel::ast::internal::Expr& node) const; + // Extract the plan steps for the given expr. + // The backing execution path is not resized -- a later call must + // overwrite the extracted region. + absl::StatusOr ExtractSubplan( + const cel::ast::internal::Expr& node); + // Note: this can only safely be called on the node being visited. absl::Status ReplaceSubplan(const cel::ast::internal::Expr& node, ExecutionPath path); diff --git a/eval/compiler/flat_expr_builder_extensions_test.cc b/eval/compiler/flat_expr_builder_extensions_test.cc index 805f5bdb6..0c64fd959 100644 --- a/eval/compiler/flat_expr_builder_extensions_test.cc +++ b/eval/compiler/flat_expr_builder_extensions_test.cc @@ -42,7 +42,6 @@ class PlannerContextTest : public testing::Test { : type_registry_(), function_registry_(), resolver_("", function_registry_, &type_registry_) {} - void SetUp() override {} protected: CelTypeRegistry type_registry_; @@ -153,6 +152,65 @@ TEST_F(PlannerContextTest, ReplacePlan) { EXPECT_THAT(context.GetSubplan(b), IsEmpty()); } +TEST_F(PlannerContextTest, ExtractPlan) { + Expr a; + Expr b; + Expr c; + PlannerContext::ProgramTree tree; + + ASSERT_OK_AND_ASSIGN(ExecutionPath path, InitSimpleTree(a, b, c, tree)); + + const ExpressionStep* b_step_ptr = path[0].get(); + const ExpressionStep* c_step_ptr = path[1].get(); + const ExpressionStep* a_step_ptr = path[2].get(); + + PlannerContext context(resolver_, type_registry_, options_, builder_warnings_, + path, tree); + + EXPECT_THAT(context.GetSubplan(a), ElementsAre(UniquePtrHolds(b_step_ptr), + UniquePtrHolds(c_step_ptr), + UniquePtrHolds(a_step_ptr))); + + ASSERT_OK_AND_ASSIGN(ExecutionPath extracted, context.ExtractSubplan(b)); + + EXPECT_THAT(extracted, ElementsAre(UniquePtrHolds(b_step_ptr))); + // Check that ownership was passed. + EXPECT_NE(extracted[0], path[0]); +} + +TEST_F(PlannerContextTest, ExtractPlanFailsOnUnfinishedNode) { + Expr a; + Expr b; + Expr c; + PlannerContext::ProgramTree tree; + + ASSERT_OK_AND_ASSIGN(ExecutionPath path, InitSimpleTree(a, b, c, tree)); + + // Mark a incomplete. + tree[&a].range_len = -1; + + PlannerContext context(resolver_, type_registry_, options_, builder_warnings_, + path, tree); + + EXPECT_THAT(context.ExtractSubplan(a), StatusIs(absl::StatusCode::kInternal)); +} + +TEST_F(PlannerContextTest, ExtractFailsOnReplacedNode) { + Expr a; + Expr b; + Expr c; + PlannerContext::ProgramTree tree; + + ASSERT_OK_AND_ASSIGN(ExecutionPath path, InitSimpleTree(a, b, c, tree)); + + PlannerContext context(resolver_, type_registry_, options_, builder_warnings_, + path, tree); + + ASSERT_OK(context.ReplaceSubplan(a, {})); + + EXPECT_THAT(context.ExtractSubplan(b), StatusIs(absl::StatusCode::kInternal)); +} + TEST_F(PlannerContextTest, ReplacePlanUpdatesParent) { Expr a; Expr b; diff --git a/eval/compiler/regex_precompilation_optimization.cc b/eval/compiler/regex_precompilation_optimization.cc new file mode 100644 index 000000000..a53475904 --- /dev/null +++ b/eval/compiler/regex_precompilation_optimization.cc @@ -0,0 +1,180 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/compiler/regex_precompilation_optimization.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/types/optional.h" +#include "base/ast_internal.h" +#include "base/builtins.h" +#include "base/internal/ast_impl.h" +#include "base/values/string_value.h" +#include "eval/compiler/flat_expr_builder_extensions.h" +#include "eval/eval/compiler_constant_step.h" +#include "eval/eval/regex_match_step.h" +#include "internal/casts.h" +#include "internal/rtti.h" + +namespace google::api::expr::runtime { +namespace { + +using cel::ast::internal::AstImpl; +using cel::ast::internal::Call; +using cel::ast::internal::Expr; +using cel::ast::internal::Reference; +using cel::internal::down_cast; +using cel::internal::TypeId; + +using ReferenceMap = absl::flat_hash_map; + +bool IsFunctionOverload( + const Expr& expr, absl::string_view function, absl::string_view overload, + size_t arity, + const absl::flat_hash_map& + reference_map) { + if (!expr.has_call_expr()) { + return false; + } + const auto& call_expr = expr.call_expr(); + if (call_expr.function() != function) { + return false; + } + if (call_expr.args().size() + (call_expr.has_target() ? 1 : 0) != arity) { + return false; + } + auto reference = reference_map.find(expr.id()); + if (reference != reference_map.end() && + reference->second.overload_id().size() == 1 && + reference->second.overload_id().front() == overload) { + return true; + } + return false; +} + +// Abstraction for deduplicating regular expressions over the course of a single +// create expression call. Should not be used during evaluation. Uses +// std::shared_ptr and std::weak_ptr. +class RegexProgramBuilder final { + public: + explicit RegexProgramBuilder(int max_program_size) + : max_program_size_(max_program_size) {} + + absl::StatusOr> BuildRegexProgram( + std::string pattern) { + auto existing = programs_.find(pattern); + if (existing != programs_.end()) { + if (auto program = existing->second.lock(); program) { + return program; + } + programs_.erase(existing); + } + auto program = std::make_shared(pattern); + if (max_program_size_ > 0 && program->ProgramSize() > max_program_size_) { + return absl::InvalidArgumentError("exceeded RE2 max program size"); + } + if (!program->ok()) { + return absl::InvalidArgumentError("invalid_argument"); + } + programs_.insert({std::move(pattern), program}); + return program; + } + + private: + const int max_program_size_; + absl::flat_hash_map> programs_; +}; + +class RegexPrecompilationOptimization : public ProgramOptimizer { + public: + explicit RegexPrecompilationOptimization(const ReferenceMap& reference_map, + int regex_max_program_size) + : reference_map_(reference_map), + regex_program_builder_(regex_max_program_size) {} + + absl::Status OnPreVisit(PlannerContext& context, const Expr& node) override { + return absl::OkStatus(); + } + + absl::Status OnPostVisit(PlannerContext& context, const Expr& node) override { + // Do not consider parse-only expressions. + if (reference_map_.empty()) { + return absl::OkStatus(); + } + + // Check that this is the correct matches overload instead of a user defined + // overload. + if (!IsFunctionOverload(node, cel::builtin::kRegexMatch, "matches_string", + 2, reference_map_)) { + return absl::OkStatus(); + } + + const Call& call_expr = node.call_expr(); + const Expr& pattern_expr = call_expr.args().back(); + + absl::optional pattern = + GetConstantString(context, pattern_expr); + if (!pattern.has_value()) { + return absl::OkStatus(); + } + + CEL_ASSIGN_OR_RETURN(auto program, regex_program_builder_.BuildRegexProgram( + std::move(pattern).value())); + + const Expr& subject_expr = + call_expr.has_target() ? call_expr.target() : call_expr.args().front(); + CEL_ASSIGN_OR_RETURN(ExecutionPath new_plan, + context.ExtractSubplan(subject_expr)); + CEL_ASSIGN_OR_RETURN(new_plan.emplace_back(), + CreateRegexMatchStep(std::move(program), node.id())); + + return context.ReplaceSubplan(node, std::move(new_plan)); + } + + private: + absl::optional GetConstantString( + PlannerContext& context, const cel::ast::internal::Expr& expr) const { + if (expr.has_const_expr() && expr.const_expr().has_string_value()) { + return expr.const_expr().string_value(); + } + + ExecutionPathView re_plan = context.GetSubplan(expr); + if (re_plan.size() == 1 && + re_plan[0]->TypeId() == TypeId()) { + const auto& constant = + down_cast(*re_plan[0]); + if (constant.value()->Is()) { + return constant.value()->As().ToString(); + } + } + + return absl::nullopt; + } + + const ReferenceMap& reference_map_; + RegexProgramBuilder regex_program_builder_; +}; + +} // namespace + +ProgramOptimizerFactory CreateRegexPrecompilationExtension( + int regex_max_program_size) { + return [=](PlannerContext& context, const AstImpl& ast) { + return std::make_unique( + ast.reference_map(), regex_max_program_size); + }; +} +} // namespace google::api::expr::runtime diff --git a/eval/compiler/regex_precompilation_optimization.h b/eval/compiler/regex_precompilation_optimization.h new file mode 100644 index 000000000..7b15d9aae --- /dev/null +++ b/eval/compiler/regex_precompilation_optimization.h @@ -0,0 +1,29 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_COMPILER_REGEX_PRECOMPILATION_OPTIMIZATION_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_COMPILER_REGEX_PRECOMPILATION_OPTIMIZATION_H_ + +#include "eval/compiler/flat_expr_builder_extensions.h" + +namespace google::api::expr::runtime { + +// Create a new extension for the FlatExprBuilder that precompiles constant +// regular expressions used in the standard 'Match' function. +ProgramOptimizerFactory CreateRegexPrecompilationExtension( + int regex_max_program_size); + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_COMPILER_REGEX_PRECOMPILATION_OPTIMIZATION_H_ diff --git a/eval/compiler/regex_precompilation_optimization_test.cc b/eval/compiler/regex_precompilation_optimization_test.cc new file mode 100644 index 000000000..973c28ecd --- /dev/null +++ b/eval/compiler/regex_precompilation_optimization_test.cc @@ -0,0 +1,217 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/compiler/regex_precompilation_optimization.h" + +#include +#include + +#include "google/api/expr/v1alpha1/checked.pb.h" +#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "base/ast_internal.h" +#include "base/internal/ast_impl.h" +#include "eval/compiler/flat_expr_builder.h" +#include "eval/compiler/flat_expr_builder_extensions.h" +#include "eval/eval/evaluator_core.h" +#include "eval/public/builtin_func_registrar.h" +#include "eval/public/cel_options.h" +#include "internal/testing.h" +#include "parser/parser.h" +#include "google/protobuf/arena.h" + +namespace google::api::expr::runtime { +namespace { + +using cel::ast::internal::CheckedExpr; +using google::api::expr::parser::Parse; + +namespace exprpb = google::api::expr::v1alpha1; + +class RegexPrecompilationExtensionTest : public testing::Test { + public: + RegexPrecompilationExtensionTest() + : type_registry_(*builder_.GetTypeRegistry()), + function_registry_(*builder_.GetRegistry()), + resolver_("", function_registry_.InternalGetRegistry(), + &type_registry_) { + options_.enable_regex = true; + options_.regex_max_program_size = 100; + options_.enable_regex_precompilation = true; + runtime_options_ = ConvertToRuntimeOptions(options_); + } + + void SetUp() override { + ASSERT_OK(RegisterBuiltinFunctions(&function_registry_, options_)); + } + + protected: + FlatExprBuilder builder_; + CelTypeRegistry& type_registry_; + CelFunctionRegistry& function_registry_; + InterpreterOptions options_; + cel::RuntimeOptions runtime_options_; + Resolver resolver_; + BuilderWarnings builder_warnings_; +}; + +TEST_F(RegexPrecompilationExtensionTest, SmokeTest) { + ProgramOptimizerFactory factory = + CreateRegexPrecompilationExtension(options_.regex_max_program_size); + ExecutionPath path; + PlannerContext::ProgramTree program_tree; + CheckedExpr expr; + cel::ast::internal::AstImpl ast_impl(std::move(expr)); + PlannerContext context(resolver_, type_registry_, runtime_options_, + builder_warnings_, path, program_tree); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr optimizer, + factory(context, ast_impl)); +} + +MATCHER_P(ExpressionPlanSizeIs, size, "") { + // This is brittle, but the most direct way to test that the plan + // was optimized. + const std::unique_ptr& plan = arg; + + const CelExpressionFlatImpl* impl = + dynamic_cast(plan.get()); + + if (impl == nullptr) return false; + *result_listener << "got size " << impl->path().size(); + return impl->path().size() == size; +} + +TEST_F(RegexPrecompilationExtensionTest, OptimizeableExpression) { + builder_.AddProgramOptimizer( + CreateRegexPrecompilationExtension(options_.regex_max_program_size)); + + ASSERT_OK_AND_ASSIGN(exprpb::ParsedExpr parsed_expr, + Parse("input.matches(r'[a-zA-Z]+[0-9]*')")); + + // Fake reference information for the matches call. + exprpb::CheckedExpr expr; + expr.mutable_expr()->Swap(parsed_expr.mutable_expr()); + expr.mutable_source_info()->Swap(parsed_expr.mutable_source_info()); + (*expr.mutable_reference_map())[2].add_overload_id("matches_string"); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr plan, + builder_.CreateExpression(&expr)); + + EXPECT_THAT(plan, ExpressionPlanSizeIs(2)); +} + +TEST_F(RegexPrecompilationExtensionTest, DoesNotOptimizeParsedExpr) { + builder_.AddProgramOptimizer( + CreateRegexPrecompilationExtension(options_.regex_max_program_size)); + + ASSERT_OK_AND_ASSIGN(exprpb::ParsedExpr expr, + Parse("input.matches(r'[a-zA-Z]+[0-9]*')")); + + ASSERT_OK_AND_ASSIGN( + std::unique_ptr plan, + builder_.CreateExpression(&expr.expr(), &expr.source_info())); + + EXPECT_THAT(plan, ExpressionPlanSizeIs(3)); +} + +TEST_F(RegexPrecompilationExtensionTest, DoesNotOptimizeNonConstRegex) { + builder_.AddProgramOptimizer( + CreateRegexPrecompilationExtension(options_.regex_max_program_size)); + + ASSERT_OK_AND_ASSIGN(exprpb::ParsedExpr parsed_expr, + Parse("input.matches(input_re)")); + + // Fake reference information for the matches call. + exprpb::CheckedExpr expr; + expr.mutable_expr()->Swap(parsed_expr.mutable_expr()); + expr.mutable_source_info()->Swap(parsed_expr.mutable_source_info()); + (*expr.mutable_reference_map())[2].add_overload_id("matches_string"); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr plan, + builder_.CreateExpression(&expr)); + + EXPECT_THAT(plan, ExpressionPlanSizeIs(3)); +} + +TEST_F(RegexPrecompilationExtensionTest, DoesNotOptimizeCompoundExpr) { + builder_.AddProgramOptimizer( + CreateRegexPrecompilationExtension(options_.regex_max_program_size)); + + ASSERT_OK_AND_ASSIGN(exprpb::ParsedExpr parsed_expr, + Parse("input.matches('abc' + 'def')")); + + // Fake reference information for the matches call. + exprpb::CheckedExpr expr; + expr.mutable_expr()->Swap(parsed_expr.mutable_expr()); + expr.mutable_source_info()->Swap(parsed_expr.mutable_source_info()); + (*expr.mutable_reference_map())[2].add_overload_id("matches_string"); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr plan, + builder_.CreateExpression(&expr)); + + EXPECT_THAT(plan, ExpressionPlanSizeIs(5)) << expr.DebugString(); +} + +class RegexConstFoldInteropTest : public RegexPrecompilationExtensionTest { + public: + RegexConstFoldInteropTest() : RegexPrecompilationExtensionTest() { + // TODO(issues/5): This applies to either version of const folding. + // Update when default is changed to new version. + builder_.set_constant_folding(true, &arena_); + } + + protected: + google::protobuf::Arena arena_; +}; + +TEST_F(RegexConstFoldInteropTest, StringConstantOptimizeable) { + builder_.AddProgramOptimizer( + CreateRegexPrecompilationExtension(options_.regex_max_program_size)); + + ASSERT_OK_AND_ASSIGN(exprpb::ParsedExpr parsed_expr, + Parse("input.matches('abc' + 'def')")); + + // Fake reference information for the matches call. + exprpb::CheckedExpr expr; + expr.mutable_expr()->Swap(parsed_expr.mutable_expr()); + expr.mutable_source_info()->Swap(parsed_expr.mutable_source_info()); + (*expr.mutable_reference_map())[2].add_overload_id("matches_string"); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr plan, + builder_.CreateExpression(&expr)); + + EXPECT_THAT(plan, ExpressionPlanSizeIs(2)) << expr.DebugString(); +} + +TEST_F(RegexConstFoldInteropTest, WrongTypeNotOptimized) { + builder_.AddProgramOptimizer( + CreateRegexPrecompilationExtension(options_.regex_max_program_size)); + + ASSERT_OK_AND_ASSIGN(exprpb::ParsedExpr parsed_expr, + Parse("input.matches(123 + 456)")); + + // Fake reference information for the matches call. + exprpb::CheckedExpr expr; + expr.mutable_expr()->Swap(parsed_expr.mutable_expr()); + expr.mutable_source_info()->Swap(parsed_expr.mutable_source_info()); + (*expr.mutable_reference_map())[2].add_overload_id("matches_string"); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr plan, + builder_.CreateExpression(&expr)); + + EXPECT_THAT(plan, ExpressionPlanSizeIs(3)) << expr.DebugString(); +} + +} // namespace +} // namespace google::api::expr::runtime diff --git a/eval/eval/evaluator_core.h b/eval/eval/evaluator_core.h index 9af2a3726..2340b66df 100644 --- a/eval/eval/evaluator_core.h +++ b/eval/eval/evaluator_core.h @@ -318,6 +318,8 @@ class CelExpressionFlatImpl : public CelExpression { CelEvaluationState* state, CelEvaluationListener callback) const override; + const ExecutionPath& path() const { return path_; } + private: const ExecutionPath path_; const CelTypeRegistry& type_registry_; diff --git a/eval/eval/regex_match_step.cc b/eval/eval/regex_match_step.cc index 085d66c86..d41d243b4 100644 --- a/eval/eval/regex_match_step.cc +++ b/eval/eval/regex_match_step.cc @@ -29,9 +29,8 @@ namespace { using ::cel::interop_internal::CreateBoolValue; -inline constexpr int kNumRegexMatchArguments = 2; +inline constexpr int kNumRegexMatchArguments = 1; inline constexpr size_t kRegexMatchStepSubject = 0; -inline constexpr size_t kRegexMatchStepPattern = 1; class RegexMatchStep final : public ExpressionStepBase { public: @@ -47,22 +46,11 @@ class RegexMatchStep final : public ExpressionStepBase { } auto input_args = frame->value_stack().GetSpan(kNumRegexMatchArguments); const auto& subject = input_args[kRegexMatchStepSubject]; - const auto& pattern = input_args[kRegexMatchStepPattern]; if (!subject->Is()) { return absl::Status(absl::StatusCode::kInternal, "First argument for regular " "expression match must be a string"); } - if (!pattern->Is()) { - return absl::Status(absl::StatusCode::kInternal, - "Second argument for regular " - "expression match must be a string"); - } - if (!pattern.As()->Equals(re2_->pattern())) { - return absl::Status( - absl::StatusCode::kInternal, - "Original pattern and supplied pattern are not the same"); - } bool match = subject.As()->Matches(*re2_); frame->value_stack().Pop(kNumRegexMatchArguments); frame->value_stack().Push(CreateBoolValue(match)); diff --git a/eval/public/BUILD b/eval/public/BUILD index b030b86c5..874810ea4 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -1250,6 +1250,7 @@ cc_library( "//eval/compiler:constant_folding", "//eval/compiler:flat_expr_builder", "//eval/compiler:qualified_reference_resolver", + "//eval/compiler:regex_precompilation_optimization", "//eval/public/structs:legacy_type_provider", "//runtime:runtime_options", "@com_google_absl//absl/status", diff --git a/eval/public/portable_cel_expr_builder_factory.cc b/eval/public/portable_cel_expr_builder_factory.cc index 5d4bb5de7..f22298ea1 100644 --- a/eval/public/portable_cel_expr_builder_factory.cc +++ b/eval/public/portable_cel_expr_builder_factory.cc @@ -24,6 +24,7 @@ #include "eval/compiler/constant_folding.h" #include "eval/compiler/flat_expr_builder.h" #include "eval/compiler/qualified_reference_resolver.h" +#include "eval/compiler/regex_precompilation_optimization.h" #include "eval/public/cel_options.h" #include "runtime/runtime_options.h" @@ -50,7 +51,6 @@ std::unique_ptr CreatePortableExprBuilder( // many build dependencies by default. builder->set_enable_comprehension_vulnerability_check( options.enable_comprehension_vulnerability_check); - builder->set_enable_regex_precompilation(options.enable_regex_precompilation); if (options.constant_folding && options.enable_updated_constant_folding) { builder->AddProgramOptimizer( @@ -61,6 +61,11 @@ std::unique_ptr CreatePortableExprBuilder( options.constant_arena); } + if (options.enable_regex_precompilation) { + builder->AddProgramOptimizer( + CreateRegexPrecompilationExtension(options.regex_max_program_size)); + } + return builder; } From 006ad0d1393d6eece214a9b125989542aef65fce Mon Sep 17 00:00:00 2001 From: jcking Date: Tue, 16 May 2023 20:47:12 +0000 Subject: [PATCH 282/303] Improve coverage a little bit PiperOrigin-RevId: 532563089 --- extensions/protobuf/struct_value.cc | 104 ++++--- extensions/protobuf/struct_value_test.cc | 378 ++++++++++++++++++++++- 2 files changed, 417 insertions(+), 65 deletions(-) diff --git a/extensions/protobuf/struct_value.cc b/extensions/protobuf/struct_value.cc index ee70b8ab2..da0949afd 100644 --- a/extensions/protobuf/struct_value.cc +++ b/extensions/protobuf/struct_value.cc @@ -1602,55 +1602,61 @@ class ParsedProtoMapValue : public CEL_MAP_VALUE_CLASS { case Kind::kAny: return ProtoValue::Create(context.value_factory(), proto_value.GetMessageValue()); - case Kind::kBool: { - // google.protobuf.BoolValue, mapped to CEL primitive bool type for - // map values. - CEL_ASSIGN_OR_RETURN(auto wrapped, - protobuf_internal::UnwrapBoolValueProto( - proto_value.GetMessageValue())); - return context.value_factory().CreateBoolValue(wrapped); - } - case Kind::kBytes: { - // google.protobuf.BytesValue, mapped to CEL primitive bytes type - // for map values. - CEL_ASSIGN_OR_RETURN(auto wrapped, - protobuf_internal::UnwrapBytesValueProto( - proto_value.GetMessageValue())); - return context.value_factory().CreateBytesValue(std::move(wrapped)); - } - case Kind::kDouble: { - // google.protobuf.{FloatValue,DoubleValue}, mapped to CEL primitive - // double type for map values. - CEL_ASSIGN_OR_RETURN(auto wrapped, - protobuf_internal::UnwrapDoubleValueProto( - proto_value.GetMessageValue())); - return context.value_factory().CreateDoubleValue(wrapped); - } - case Kind::kInt: { - // google.protobuf.{Int32Value,Int64Value}, mapped to CEL primitive - // int type for map values. - CEL_ASSIGN_OR_RETURN(auto wrapped, - protobuf_internal::UnwrapIntValueProto( - proto_value.GetMessageValue())); - return context.value_factory().CreateIntValue(wrapped); - } - case Kind::kString: { - // google.protobuf.StringValue, mapped to CEL primitive bytes type - // for map values. - CEL_ASSIGN_OR_RETURN(auto wrapped, - protobuf_internal::UnwrapStringValueProto( - proto_value.GetMessageValue())); - return context.value_factory().CreateUncheckedStringValue( - std::move(wrapped)); - } - case Kind::kUint: { - // google.protobuf.{UInt32Value,UInt64Value}, mapped to CEL - // primitive uint type for map values. - CEL_ASSIGN_OR_RETURN(auto wrapped, - protobuf_internal::UnwrapUIntValueProto( - proto_value.GetMessageValue())); - return context.value_factory().CreateUintValue(wrapped); - } + case Kind::kWrapper: + switch (type->As().wrapped()->kind()) { + case Kind::kBool: { + // google.protobuf.BoolValue, mapped to CEL primitive bool type + // for map values. + CEL_ASSIGN_OR_RETURN(auto wrapped, + protobuf_internal::UnwrapBoolValueProto( + proto_value.GetMessageValue())); + return context.value_factory().CreateBoolValue(wrapped); + } + case Kind::kBytes: { + // google.protobuf.BytesValue, mapped to CEL primitive bytes + // type for map values. + CEL_ASSIGN_OR_RETURN(auto wrapped, + protobuf_internal::UnwrapBytesValueProto( + proto_value.GetMessageValue())); + return context.value_factory().CreateBytesValue( + std::move(wrapped)); + } + case Kind::kDouble: { + // google.protobuf.{FloatValue,DoubleValue}, mapped to CEL + // primitive double type for map values. + CEL_ASSIGN_OR_RETURN(auto wrapped, + protobuf_internal::UnwrapDoubleValueProto( + proto_value.GetMessageValue())); + return context.value_factory().CreateDoubleValue(wrapped); + } + case Kind::kInt: { + // google.protobuf.{Int32Value,Int64Value}, mapped to CEL + // primitive int type for map values. + CEL_ASSIGN_OR_RETURN(auto wrapped, + protobuf_internal::UnwrapIntValueProto( + proto_value.GetMessageValue())); + return context.value_factory().CreateIntValue(wrapped); + } + case Kind::kString: { + // google.protobuf.StringValue, mapped to CEL primitive bytes + // type for map values. + CEL_ASSIGN_OR_RETURN(auto wrapped, + protobuf_internal::UnwrapStringValueProto( + proto_value.GetMessageValue())); + return context.value_factory().CreateUncheckedStringValue( + std::move(wrapped)); + } + case Kind::kUint: { + // google.protobuf.{UInt32Value,UInt64Value}, mapped to CEL + // primitive uint type for map values. + CEL_ASSIGN_OR_RETURN(auto wrapped, + protobuf_internal::UnwrapUIntValueProto( + proto_value.GetMessageValue())); + return context.value_factory().CreateUintValue(wrapped); + } + default: + ABSL_UNREACHABLE(); + } case Kind::kStruct: return context.value_factory() .CreateBorrowedStructValue< diff --git a/extensions/protobuf/struct_value_test.cc b/extensions/protobuf/struct_value_test.cc index fe15dd5da..e45a877c8 100644 --- a/extensions/protobuf/struct_value_test.cc +++ b/extensions/protobuf/struct_value_test.cc @@ -762,6 +762,7 @@ void TestGetWrapperFieldImpl( absl::FunctionRef>( const Handle&, const StructValue::GetFieldContext&)> get_field, + absl::string_view debug_string, absl::FunctionRef&)> unset_field_tester, absl::FunctionRef test_message_maker, absl::FunctionRef&)> @@ -777,6 +778,7 @@ void TestGetWrapperFieldImpl( get_field(value_without, StructValue::GetFieldContext(value_factory) .set_unbox_null_wrapper_types(true))); EXPECT_TRUE(field->Is()); + EXPECT_EQ(field->DebugString(), "null"); ASSERT_OK_AND_ASSIGN( field, get_field(value_without, StructValue::GetFieldContext(value_factory) @@ -788,11 +790,13 @@ void TestGetWrapperFieldImpl( ASSERT_OK_AND_ASSIGN( field, get_field(value_with, StructValue::GetFieldContext(value_factory))); + EXPECT_EQ(field->DebugString(), debug_string); ASSERT_NO_FATAL_FAILURE(set_field_tester(value_factory, field)); } void TestGetWrapperFieldByName( MemoryManager& memory_manager, absl::string_view name, + absl::string_view debug_string, absl::FunctionRef&)> unset_field_tester, absl::FunctionRef test_message_maker, absl::FunctionRef&)> @@ -803,11 +807,12 @@ void TestGetWrapperFieldByName( const StructValue::GetFieldContext& context) { return value->GetFieldByName(context, name); }, - unset_field_tester, test_message_maker, set_field_tester); + debug_string, unset_field_tester, test_message_maker, set_field_tester); } void TestGetWrapperFieldByNumber( MemoryManager& memory_manager, int64_t number, + absl::string_view debug_string, absl::FunctionRef&)> unset_field_tester, absl::FunctionRef test_message_maker, absl::FunctionRef&)> @@ -818,29 +823,33 @@ void TestGetWrapperFieldByNumber( const StructValue::GetFieldContext& context) { return value->GetFieldByNumber(context, number); }, - unset_field_tester, test_message_maker, set_field_tester); + debug_string, unset_field_tester, test_message_maker, set_field_tester); } void TestGetWrapperField( MemoryManager& memory_manager, absl::string_view name, + absl::string_view debug_string, absl::FunctionRef&)> unset_field_tester, absl::FunctionRef test_message_maker, absl::FunctionRef&)> set_field_tester) { - TestGetWrapperFieldByName(memory_manager, name, unset_field_tester, - test_message_maker, set_field_tester); + TestGetWrapperFieldByName(memory_manager, name, debug_string, + unset_field_tester, test_message_maker, + set_field_tester); TestGetWrapperFieldByNumber( - memory_manager, TestMessageFieldNameToNumber(name), unset_field_tester, - test_message_maker, set_field_tester); + memory_manager, TestMessageFieldNameToNumber(name), debug_string, + unset_field_tester, test_message_maker, set_field_tester); } void TestGetWrapperField( MemoryManager& memory_manager, absl::string_view name, + absl::string_view debug_string, absl::FunctionRef&)> unset_field_tester, absl::FunctionRef test_message_maker, absl::FunctionRef&)> set_field_tester) { TestGetWrapperField( - memory_manager, name, unset_field_tester, test_message_maker, + memory_manager, name, debug_string, unset_field_tester, + test_message_maker, [&](ValueFactory& value_factory, const Handle& field) { set_field_tester(field); }); @@ -851,7 +860,7 @@ void TestGetWrapperField( TEST_P(ProtoStructValueTest, BoolWrapperGetField) { TEST_GET_WRAPPER_FIELD( - memory_manager(), "single_bool_wrapper", + memory_manager(), "single_bool_wrapper", "true", [](const Handle& field) { EXPECT_FALSE(field.As()->value()); }, @@ -865,7 +874,7 @@ TEST_P(ProtoStructValueTest, BoolWrapperGetField) { TEST_P(ProtoStructValueTest, Int32WrapperGetField) { TEST_GET_WRAPPER_FIELD( - memory_manager(), "single_int32_wrapper", + memory_manager(), "single_int32_wrapper", "1", [](const Handle& field) { EXPECT_EQ(field.As()->value(), 0); }, @@ -879,7 +888,7 @@ TEST_P(ProtoStructValueTest, Int32WrapperGetField) { TEST_P(ProtoStructValueTest, Int64WrapperGetField) { TEST_GET_WRAPPER_FIELD( - memory_manager(), "single_int64_wrapper", + memory_manager(), "single_int64_wrapper", "1", [](const Handle& field) { EXPECT_EQ(field.As()->value(), 0); }, @@ -893,7 +902,7 @@ TEST_P(ProtoStructValueTest, Int64WrapperGetField) { TEST_P(ProtoStructValueTest, Uint32WrapperGetField) { TEST_GET_WRAPPER_FIELD( - memory_manager(), "single_uint32_wrapper", + memory_manager(), "single_uint32_wrapper", "1u", [](const Handle& field) { EXPECT_EQ(field.As()->value(), 0); }, @@ -907,7 +916,7 @@ TEST_P(ProtoStructValueTest, Uint32WrapperGetField) { TEST_P(ProtoStructValueTest, Uint64WrapperGetField) { TEST_GET_WRAPPER_FIELD( - memory_manager(), "single_uint64_wrapper", + memory_manager(), "single_uint64_wrapper", "1u", [](const Handle& field) { EXPECT_EQ(field.As()->value(), 0); }, @@ -921,7 +930,7 @@ TEST_P(ProtoStructValueTest, Uint64WrapperGetField) { TEST_P(ProtoStructValueTest, FloatWrapperGetField) { TEST_GET_WRAPPER_FIELD( - memory_manager(), "single_float_wrapper", + memory_manager(), "single_float_wrapper", "1.0", [](const Handle& field) { EXPECT_EQ(field.As()->value(), 0); }, @@ -935,7 +944,7 @@ TEST_P(ProtoStructValueTest, FloatWrapperGetField) { TEST_P(ProtoStructValueTest, DoubleWrapperGetField) { TEST_GET_WRAPPER_FIELD( - memory_manager(), "single_double_wrapper", + memory_manager(), "single_double_wrapper", "1.0", [](const Handle& field) { EXPECT_EQ(field.As()->value(), 0); }, @@ -949,7 +958,7 @@ TEST_P(ProtoStructValueTest, DoubleWrapperGetField) { TEST_P(ProtoStructValueTest, BytesWrapperGetField) { TEST_GET_WRAPPER_FIELD( - memory_manager(), "single_bytes_wrapper", + memory_manager(), "single_bytes_wrapper", "b\"foo\"", [](const Handle& field) { EXPECT_EQ(field.As()->ToString(), ""); }, @@ -963,7 +972,7 @@ TEST_P(ProtoStructValueTest, BytesWrapperGetField) { TEST_P(ProtoStructValueTest, StringWrapperGetField) { TEST_GET_WRAPPER_FIELD( - memory_manager(), "single_string_wrapper", + memory_manager(), "single_string_wrapper", "\"foo\"", [](const Handle& field) { EXPECT_EQ(field.As()->ToString(), ""); }, @@ -1642,6 +1651,83 @@ TEST_P(ProtoStructValueTest, AnyListGetField) { }); } +void TestGetMapFieldImpl( + MemoryManager& memory_manager, + absl::FunctionRef>( + const Handle&, const StructValue::GetFieldContext&)> + get_field, + absl::FunctionRef&)> unset_field_tester, + absl::FunctionRef test_message_maker, + absl::FunctionRef&)> + set_field_tester) { + TypeFactory type_factory(memory_manager); + ProtoTypeProvider type_provider; + TypeManager type_manager(type_factory, type_provider); + ValueFactory value_factory(type_manager); + ASSERT_OK_AND_ASSIGN(auto value_without, + ProtoValue::Create(value_factory, CreateTestMessage())); + ASSERT_OK_AND_ASSIGN( + auto field, + get_field(value_without, StructValue::GetFieldContext(value_factory))); + ASSERT_TRUE(field->Is()); + ASSERT_NO_FATAL_FAILURE(unset_field_tester(field.As())); + ASSERT_OK_AND_ASSIGN( + auto value_with, + ProtoValue::Create(value_factory, CreateTestMessage(test_message_maker))); + ASSERT_OK_AND_ASSIGN( + field, + get_field(value_with, StructValue::GetFieldContext(value_factory))); + ASSERT_TRUE(field->Is()); + ASSERT_NO_FATAL_FAILURE( + set_field_tester(value_factory, field.As())); +} + +void TestGetMapFieldByName( + MemoryManager& memory_manager, absl::string_view name, + absl::FunctionRef&)> unset_field_tester, + absl::FunctionRef test_message_maker, + absl::FunctionRef&)> + set_field_tester) { + TestGetMapFieldImpl( + memory_manager, + [&](const Handle& value, + const StructValue::GetFieldContext& context) { + return value->GetFieldByName(context, name); + }, + unset_field_tester, test_message_maker, set_field_tester); +} + +void TestGetMapFieldByNumber( + MemoryManager& memory_manager, int64_t number, + absl::FunctionRef&)> unset_field_tester, + absl::FunctionRef test_message_maker, + absl::FunctionRef&)> + set_field_tester) { + TestGetMapFieldImpl( + memory_manager, + [&](const Handle& value, + const StructValue::GetFieldContext& context) { + return value->GetFieldByNumber(context, number); + }, + unset_field_tester, test_message_maker, set_field_tester); +} + +void TestGetMapField( + MemoryManager& memory_manager, absl::string_view name, + absl::FunctionRef&)> unset_field_tester, + absl::FunctionRef test_message_maker, + absl::FunctionRef&)> + set_field_tester) { + TestGetMapFieldByName(memory_manager, name, unset_field_tester, + test_message_maker, set_field_tester); + TestGetMapFieldByNumber(memory_manager, TestMessageFieldNameToNumber(name), + unset_field_tester, test_message_maker, + set_field_tester); +} + +#define TEST_GET_MAP_FIELD(...) \ + ASSERT_NO_FATAL_FAILURE(TestGetMapField(__VA_ARGS__)) + template void TestMapHasField(MemoryManager& memory_manager, absl::string_view map_field_name, @@ -2980,6 +3066,266 @@ TEST_P(ProtoStructValueTest, BoolMessageMapGetField) { std::make_pair(true, CreateTestNestedMessage(2)), nullptr); } +void EmptyMapTester(const Handle& field) { + EXPECT_TRUE(field->empty()); + EXPECT_EQ(field->size(), 0); + EXPECT_EQ(field->DebugString(), "{}"); +} + +TEST_P(ProtoStructValueTest, BoolStructMapGetField) { + TEST_GET_MAP_FIELD( + memory_manager(), "map_bool_struct", EmptyMapTester, + [](TestAllTypes& message) { + google::protobuf::Struct proto; + google::protobuf::Value value; + value.set_bool_value(false); + proto.mutable_fields()->insert({"foo", value}); + message.mutable_map_bool_struct()->insert({false, proto}); + }, + [](ValueFactory& value_factory, const Handle& field) { + EXPECT_FALSE(field->empty()); + EXPECT_EQ(field->size(), 1); + ASSERT_OK_AND_ASSIGN(auto value, + field->Get(MapValue::GetContext(value_factory), + value_factory.CreateBoolValue(false))); + ASSERT_TRUE(value); + ASSERT_TRUE((*value)->Is()); + EXPECT_EQ((*value).As()->size(), 1); + ASSERT_OK_AND_ASSIGN(auto subvalue, + (*value).As()->Get( + MapValue::GetContext(value_factory), + Must(value_factory.CreateStringValue("foo")))); + ASSERT_TRUE(subvalue); + ASSERT_TRUE((*subvalue)->Is()); + EXPECT_FALSE((*subvalue)->As().value()); + }); +} + +TEST_P(ProtoStructValueTest, BoolValueMapGetField) { + TEST_GET_MAP_FIELD( + memory_manager(), "map_bool_value", EmptyMapTester, + [](TestAllTypes& message) { + google::protobuf::Value value; + value.set_bool_value(true); + message.mutable_map_bool_value()->insert({false, value}); + }, + [](ValueFactory& value_factory, const Handle& field) { + EXPECT_FALSE(field->empty()); + EXPECT_EQ(field->size(), 1); + ASSERT_OK_AND_ASSIGN(auto value, + field->Get(MapValue::GetContext(value_factory), + value_factory.CreateBoolValue(false))); + ASSERT_TRUE(value); + ASSERT_TRUE((*value)->Is()); + EXPECT_TRUE((*value)->As().value()); + }); +} + +TEST_P(ProtoStructValueTest, BoolListValueMapGetField) { + TEST_GET_MAP_FIELD( + memory_manager(), "map_bool_list_value", EmptyMapTester, + [](TestAllTypes& message) { + google::protobuf::ListValue value; + value.add_values(); + message.mutable_map_bool_list_value()->insert({false, value}); + }, + [](ValueFactory& value_factory, const Handle& field) { + EXPECT_FALSE(field->empty()); + EXPECT_EQ(field->size(), 1); + ASSERT_OK_AND_ASSIGN(auto value, + field->Get(MapValue::GetContext(value_factory), + value_factory.CreateBoolValue(false))); + ASSERT_TRUE(value); + ASSERT_TRUE((*value)->Is()); + EXPECT_FALSE((*value)->As().empty()); + EXPECT_EQ((*value)->As().size(), 1); + ASSERT_OK_AND_ASSIGN(auto element, + (*value)->As().Get( + ListValue::GetContext(value_factory), 0)); + ASSERT_TRUE(element->Is()); + }); +} + +TEST_P(ProtoStructValueTest, BoolBoolWrapperMapGetField) { + TEST_GET_MAP_FIELD( + memory_manager(), "map_bool_bool_wrapper", EmptyMapTester, + [](TestAllTypes& message) { + google::protobuf::BoolValue value; + value.set_value(true); + message.mutable_map_bool_bool_wrapper()->insert({false, value}); + }, + [](ValueFactory& value_factory, const Handle& field) { + EXPECT_FALSE(field->empty()); + EXPECT_EQ(field->size(), 1); + ASSERT_OK_AND_ASSIGN(auto value, + field->Get(MapValue::GetContext(value_factory), + value_factory.CreateBoolValue(false))); + ASSERT_TRUE(value); + ASSERT_TRUE((*value)->Is()); + EXPECT_TRUE((*value)->As().value()); + }); +} + +TEST_P(ProtoStructValueTest, BoolInt32WrapperMapGetField) { + TEST_GET_MAP_FIELD( + memory_manager(), "map_bool_int32_wrapper", EmptyMapTester, + [](TestAllTypes& message) { + google::protobuf::Int32Value value; + value.set_value(1); + message.mutable_map_bool_int32_wrapper()->insert({false, value}); + }, + [](ValueFactory& value_factory, const Handle& field) { + EXPECT_FALSE(field->empty()); + EXPECT_EQ(field->size(), 1); + ASSERT_OK_AND_ASSIGN(auto value, + field->Get(MapValue::GetContext(value_factory), + value_factory.CreateBoolValue(false))); + ASSERT_TRUE(value); + ASSERT_TRUE((*value)->Is()); + EXPECT_EQ((*value)->As().value(), 1); + }); +} + +TEST_P(ProtoStructValueTest, BoolInt64WrapperMapGetField) { + TEST_GET_MAP_FIELD( + memory_manager(), "map_bool_int64_wrapper", EmptyMapTester, + [](TestAllTypes& message) { + google::protobuf::Int64Value value; + value.set_value(1); + message.mutable_map_bool_int64_wrapper()->insert({false, value}); + }, + [](ValueFactory& value_factory, const Handle& field) { + EXPECT_FALSE(field->empty()); + EXPECT_EQ(field->size(), 1); + ASSERT_OK_AND_ASSIGN(auto value, + field->Get(MapValue::GetContext(value_factory), + value_factory.CreateBoolValue(false))); + ASSERT_TRUE(value); + ASSERT_TRUE((*value)->Is()); + EXPECT_EQ((*value)->As().value(), 1); + }); +} + +TEST_P(ProtoStructValueTest, BoolUInt32WrapperMapGetField) { + TEST_GET_MAP_FIELD( + memory_manager(), "map_bool_uint32_wrapper", EmptyMapTester, + [](TestAllTypes& message) { + google::protobuf::UInt32Value value; + value.set_value(1); + message.mutable_map_bool_uint32_wrapper()->insert({false, value}); + }, + [](ValueFactory& value_factory, const Handle& field) { + EXPECT_FALSE(field->empty()); + EXPECT_EQ(field->size(), 1); + ASSERT_OK_AND_ASSIGN(auto value, + field->Get(MapValue::GetContext(value_factory), + value_factory.CreateBoolValue(false))); + ASSERT_TRUE(value); + ASSERT_TRUE((*value)->Is()); + EXPECT_EQ((*value)->As().value(), 1); + }); +} + +TEST_P(ProtoStructValueTest, BoolUInt64WrapperMapGetField) { + TEST_GET_MAP_FIELD( + memory_manager(), "map_bool_uint64_wrapper", EmptyMapTester, + [](TestAllTypes& message) { + google::protobuf::UInt64Value value; + value.set_value(1); + message.mutable_map_bool_uint64_wrapper()->insert({false, value}); + }, + [](ValueFactory& value_factory, const Handle& field) { + EXPECT_FALSE(field->empty()); + EXPECT_EQ(field->size(), 1); + ASSERT_OK_AND_ASSIGN(auto value, + field->Get(MapValue::GetContext(value_factory), + value_factory.CreateBoolValue(false))); + ASSERT_TRUE(value); + ASSERT_TRUE((*value)->Is()); + EXPECT_EQ((*value)->As().value(), 1); + }); +} + +TEST_P(ProtoStructValueTest, BoolFloatWrapperMapGetField) { + TEST_GET_MAP_FIELD( + memory_manager(), "map_bool_float_wrapper", EmptyMapTester, + [](TestAllTypes& message) { + google::protobuf::FloatValue value; + value.set_value(1); + message.mutable_map_bool_float_wrapper()->insert({false, value}); + }, + [](ValueFactory& value_factory, const Handle& field) { + EXPECT_FALSE(field->empty()); + EXPECT_EQ(field->size(), 1); + ASSERT_OK_AND_ASSIGN(auto value, + field->Get(MapValue::GetContext(value_factory), + value_factory.CreateBoolValue(false))); + ASSERT_TRUE(value); + ASSERT_TRUE((*value)->Is()); + EXPECT_EQ((*value)->As().value(), 1); + }); +} + +TEST_P(ProtoStructValueTest, BoolDoubleWrapperMapGetField) { + TEST_GET_MAP_FIELD( + memory_manager(), "map_bool_double_wrapper", EmptyMapTester, + [](TestAllTypes& message) { + google::protobuf::DoubleValue value; + value.set_value(1); + message.mutable_map_bool_double_wrapper()->insert({false, value}); + }, + [](ValueFactory& value_factory, const Handle& field) { + EXPECT_FALSE(field->empty()); + EXPECT_EQ(field->size(), 1); + ASSERT_OK_AND_ASSIGN(auto value, + field->Get(MapValue::GetContext(value_factory), + value_factory.CreateBoolValue(false))); + ASSERT_TRUE(value); + ASSERT_TRUE((*value)->Is()); + EXPECT_EQ((*value)->As().value(), 1); + }); +} + +TEST_P(ProtoStructValueTest, BoolBytesWrapperMapGetField) { + TEST_GET_MAP_FIELD( + memory_manager(), "map_bool_bytes_wrapper", EmptyMapTester, + [](TestAllTypes& message) { + google::protobuf::BytesValue value; + value.set_value("foo"); + message.mutable_map_bool_bytes_wrapper()->insert({false, value}); + }, + [](ValueFactory& value_factory, const Handle& field) { + EXPECT_FALSE(field->empty()); + EXPECT_EQ(field->size(), 1); + ASSERT_OK_AND_ASSIGN(auto value, + field->Get(MapValue::GetContext(value_factory), + value_factory.CreateBoolValue(false))); + ASSERT_TRUE(value); + ASSERT_TRUE((*value)->Is()); + EXPECT_EQ((*value)->As().ToString(), "foo"); + }); +} + +TEST_P(ProtoStructValueTest, BoolStringWrapperMapGetField) { + TEST_GET_MAP_FIELD( + memory_manager(), "map_bool_string_wrapper", EmptyMapTester, + [](TestAllTypes& message) { + google::protobuf::StringValue value; + value.set_value("foo"); + message.mutable_map_bool_string_wrapper()->insert({false, value}); + }, + [](ValueFactory& value_factory, const Handle& field) { + EXPECT_FALSE(field->empty()); + EXPECT_EQ(field->size(), 1); + ASSERT_OK_AND_ASSIGN(auto value, + field->Get(MapValue::GetContext(value_factory), + value_factory.CreateBoolValue(false))); + ASSERT_TRUE(value); + ASSERT_TRUE((*value)->Is()); + EXPECT_EQ((*value)->As().ToString(), "foo"); + }); +} + TEST_P(ProtoStructValueTest, Int32NullValueMapGetField) { TestMapGetField( memory_manager(), "map_int32_null_value", "{0: null, 1: null}", From 3ff5836ed8dadfb49ffb9822121a59bea7bc2838 Mon Sep 17 00:00:00 2001 From: kuat Date: Tue, 16 May 2023 21:28:37 +0000 Subject: [PATCH 283/303] OSS export PiperOrigin-RevId: 532574621 --- .bazelrc | 2 +- .bazelversion | 2 +- base/function_result.h | 2 +- base/value.h | 1 - base/values/null_value.h | 1 + base/values/string_value.cc | 1 - base/values/string_value.h | 3 +- bazel/abseil.patch | 42 ------------------- bazel/deps.bzl | 22 +++++----- eval/compiler/flat_expr_builder_test.cc | 1 + eval/internal/interop.cc | 4 +- eval/public/BUILD | 2 +- eval/public/ast_rewrite_native_test.cc | 1 + eval/public/cel_value.h | 4 +- .../container_function_registrar_test.cc | 2 + .../structs/cel_proto_lite_wrap_util_test.cc | 14 +++---- eval/tests/BUILD | 1 + eval/tests/benchmark_test.cc | 1 + .../expression_builder_benchmark_test.cc | 1 + eval/tests/memory_safety_test.cc | 4 +- extensions/protobuf/BUILD | 1 + extensions/protobuf/struct_value_test.cc | 18 ++++---- extensions/protobuf/type_test.cc | 4 +- extensions/protobuf/value_test.cc | 3 +- testutil/util.h | 14 ++++++- 25 files changed, 65 insertions(+), 86 deletions(-) delete mode 100644 bazel/abseil.patch diff --git a/.bazelrc b/.bazelrc index 2521e741d..d4fe870a3 100644 --- a/.bazelrc +++ b/.bazelrc @@ -1,4 +1,4 @@ -build --cxxopt=-std=c++17 +build --cxxopt=-std=c++17 --host_cxxopt=-std=c++17 build --cxxopt=-fsized-deallocation # Enable matchers in googletest diff --git a/.bazelversion b/.bazelversion index 0062ac971..dfda3e0b4 100644 --- a/.bazelversion +++ b/.bazelversion @@ -1 +1 @@ -5.0.0 +6.1.0 diff --git a/base/function_result.h b/base/function_result.h index da6e164ea..fafab8899 100644 --- a/base/function_result.h +++ b/base/function_result.h @@ -26,7 +26,7 @@ namespace cel { // allows for lazy evaluation of expensive functions. class FunctionResult final { public: - FunctionResult() = default; + FunctionResult() = delete; FunctionResult(const FunctionResult&) = default; FunctionResult(FunctionResult&&) = default; FunctionResult& operator=(const FunctionResult&) = default; diff --git a/base/value.h b/base/value.h index 21ba4ea5b..3cf89933b 100644 --- a/base/value.h +++ b/base/value.h @@ -325,7 +325,6 @@ CEL_INTERNAL_VALUE_DECL(Value); template \ friend struct base_internal::AnyData; \ \ - value_class() = default; \ value_class(const value_class&) = default; \ value_class(value_class&&) = default; \ value_class& operator=(const value_class&) = default; \ diff --git a/base/values/null_value.h b/base/values/null_value.h index d2f37da78..bedd0c0ea 100644 --- a/base/values/null_value.h +++ b/base/values/null_value.h @@ -48,6 +48,7 @@ class NullValue final : public base_internal::SimpleValue { using Base::type; private: + NullValue() = default; CEL_INTERNAL_SIMPLE_VALUE_MEMBERS(NullValue); }; diff --git a/base/values/string_value.cc b/base/values/string_value.cc index ae8808bc5..77eb05ebc 100644 --- a/base/values/string_value.cc +++ b/base/values/string_value.cc @@ -22,7 +22,6 @@ #include "base/types/string_type.h" #include "internal/strings.h" #include "internal/utf8.h" -#include "re2/re2.h" namespace cel { diff --git a/base/values/string_value.h b/base/values/string_value.h index 426f3b44a..f5eb1a419 100644 --- a/base/values/string_value.h +++ b/base/values/string_value.h @@ -30,8 +30,7 @@ #include "base/type.h" #include "base/types/string_type.h" #include "base/value.h" - -class RE2; +#include "re2/re2.h" namespace cel { diff --git a/bazel/abseil.patch b/bazel/abseil.patch deleted file mode 100644 index d52556466..000000000 --- a/bazel/abseil.patch +++ /dev/null @@ -1,42 +0,0 @@ -# Force internal versions of std classes per -# https://abseil.io/docs/cpp/guides/options -diff --git a/absl/base/options.h b/absl/base/options.h -index 230bf1e..6e1b9e5 100644 ---- a/absl/base/options.h -+++ b/absl/base/options.h -@@ -100,7 +100,7 @@ - // User code should not inspect this macro. To check in the preprocessor if - // absl::any is a typedef of std::any, use the feature macro ABSL_USES_STD_ANY. - --#define ABSL_OPTION_USE_STD_ANY 2 -+#define ABSL_OPTION_USE_STD_ANY 0 - - - // ABSL_OPTION_USE_STD_OPTIONAL -@@ -127,7 +127,7 @@ - // absl::optional is a typedef of std::optional, use the feature macro - // ABSL_USES_STD_OPTIONAL. - --#define ABSL_OPTION_USE_STD_OPTIONAL 2 -+#define ABSL_OPTION_USE_STD_OPTIONAL 0 - - - // ABSL_OPTION_USE_STD_STRING_VIEW -@@ -154,7 +154,7 @@ - // absl::string_view is a typedef of std::string_view, use the feature macro - // ABSL_USES_STD_STRING_VIEW. - --#define ABSL_OPTION_USE_STD_STRING_VIEW 2 -+#define ABSL_OPTION_USE_STD_STRING_VIEW 0 - - // ABSL_OPTION_USE_STD_VARIANT - // -@@ -180,7 +180,7 @@ - // absl::variant is a typedef of std::variant, use the feature macro - // ABSL_USES_STD_VARIANT. - --#define ABSL_OPTION_USE_STD_VARIANT 2 -+#define ABSL_OPTION_USE_STD_VARIANT 0 - - - // ABSL_OPTION_USE_INLINE_NAMESPACE diff --git a/bazel/deps.bzl b/bazel/deps.bzl index 814a2788c..b7d20da3e 100644 --- a/bazel/deps.bzl +++ b/bazel/deps.bzl @@ -7,16 +7,14 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive", "http_jar") def base_deps(): """Base evaluator and test dependencies.""" - # 2022-09-08 - ABSL_SHA1 = "518984e432e0597fd4e66a9c52148e8dec2bb46a" - ABSL_SHA256 = "97e721f8f2a49c507190821a76cdf1c8b659eb49728e6dcf527670f943b2ba60" + # 2023-05-15 + ABSL_SHA1 = "3aa3377ef66e6388ed19fd7c862bf0dc3a3630e0" + ABSL_SHA256 = "91b144618b8d608764b556d56eba07d4a6055429807e8d8fca0566cc5b66485e" http_archive( name = "com_google_absl", urls = ["https://github.com/abseil/abseil-cpp/archive/" + ABSL_SHA1 + ".zip"], strip_prefix = "abseil-cpp-" + ABSL_SHA1, sha256 = ABSL_SHA256, - patches = ["//bazel:abseil.patch"], - patch_args = ["-p1"], ) # v1.11.0 @@ -39,9 +37,9 @@ def base_deps(): sha256 = BENCHMARK_SHA256, ) - # 2021-09-01 - RE2_SHA1 = "8e08f47b11b413302749c0d8b17a1c94777495d5" - RE2_SHA256 = "d635a3353bb8ffc33b0779c97c1c9d6f2dbdda286106a73bbcf498f66edacd74" + # 2022-02-18 + RE2_SHA1 = "f6834581a8913c03d087de1e5d5b479f8a870400" + RE2_SHA256 = "ef7f29b79f9e3a8e4030ea2a0f71a66bd99aa0376fe641d86d47d6129c7f5aed" http_archive( name = "com_googlesource_code_re2", urls = ["https://github.com/google/re2/archive/" + RE2_SHA1 + ".zip"], @@ -49,8 +47,8 @@ def base_deps(): sha256 = RE2_SHA256, ) - PROTOBUF_VERSION = "3.21.1" - PROTOBUF_SHA = "a295dd3b9551d3e2749a9969583dea110c6cdcc39d02088f7c7bb1100077e081" + PROTOBUF_VERSION = "23.0" + PROTOBUF_SHA = "b8faf8487cc364e5c2b47a9abd77512bc79a6389ea45392ca938ba7617eae877" http_archive( name = "com_google_protobuf", sha256 = PROTOBUF_SHA, @@ -124,10 +122,10 @@ def cel_spec_deps(): ], ) - CEL_SPEC_GIT_SHA = "2cfa4f6a2dd7cb101459f6a286a4920c7322649f" # 9/7/2022 + CEL_SPEC_GIT_SHA = "c8bbae9828aea503e17300affc7e0b7264a4983e" # 4/28/2023 http_archive( name = "com_google_cel_spec", - sha256 = "78bfc17821607919724b033f1ba6e636d0cdfe056363055f4ab7f46b19e6a184", + sha256 = "d19c06c91162b10c9d2a8f4799ce231ecfa100fb6f8258d767a56efcdfc9d46f", strip_prefix = "cel-spec-" + CEL_SPEC_GIT_SHA, urls = ["https://github.com/google/cel-spec/archive/" + CEL_SPEC_GIT_SHA + ".zip"], ) diff --git a/eval/compiler/flat_expr_builder_test.cc b/eval/compiler/flat_expr_builder_test.cc index 2912fa087..e49f92fe9 100644 --- a/eval/compiler/flat_expr_builder_test.cc +++ b/eval/compiler/flat_expr_builder_test.cc @@ -28,6 +28,7 @@ #include "google/api/expr/v1alpha1/syntax.pb.h" #include "google/protobuf/field_mask.pb.h" #include "google/protobuf/descriptor.pb.h" +#include "google/protobuf/text_format.h" #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/strings/str_format.h" diff --git a/eval/internal/interop.cc b/eval/internal/interop.cc index a126c322e..c2ae7eb51 100644 --- a/eval/internal/interop.cc +++ b/eval/internal/interop.cc @@ -578,7 +578,7 @@ Handle LegacyValueToModernValueOrDie( google::protobuf::Arena* arena, const google::api::expr::runtime::CelValue& value, bool unchecked) { auto modern_value = FromLegacyValue(arena, value, unchecked); - CHECK_OK(modern_value); // Crash OK + ABSL_CHECK_OK(modern_value); // Crash OK return std::move(modern_value).value(); } @@ -614,7 +614,7 @@ std::vector> LegacyValueToModernValueOrDie( google::api::expr::runtime::CelValue ModernValueToLegacyValueOrDie( google::protobuf::Arena* arena, const Handle& value, bool unchecked) { auto legacy_value = ToLegacyValue(arena, value, unchecked); - CHECK_OK(legacy_value); // Crash OK + ABSL_CHECK_OK(legacy_value); // Crash OK return std::move(legacy_value).value(); } diff --git a/eval/public/BUILD b/eval/public/BUILD index 874810ea4..85c051fac 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -888,7 +888,7 @@ cc_test( "//extensions/protobuf:ast_converters", "//internal:testing", "//parser", - "//testutil:util", + "@com_google_protobuf//:protobuf", ], ) diff --git a/eval/public/ast_rewrite_native_test.cc b/eval/public/ast_rewrite_native_test.cc index de4fc630c..e35cfcf71 100644 --- a/eval/public/ast_rewrite_native_test.cc +++ b/eval/public/ast_rewrite_native_test.cc @@ -16,6 +16,7 @@ #include +#include "google/protobuf/text_format.h" #include "eval/public/ast_visitor_native.h" #include "eval/public/source_position_native.h" #include "extensions/protobuf/ast_converters.h" diff --git a/eval/public/cel_value.h b/eval/public/cel_value.h index d824c5f11..031585652 100644 --- a/eval/public/cel_value.h +++ b/eval/public/cel_value.h @@ -547,7 +547,7 @@ class CelList { virtual ~CelList() {} private: - friend class cel::interop_internal::CelListAccess; + friend struct cel::interop_internal::CelListAccess; virtual cel::internal::TypeInfo TypeId() const { return cel::internal::TypeInfo(); @@ -625,7 +625,7 @@ class CelMap { virtual ~CelMap() {} private: - friend class cel::interop_internal::CelMapAccess; + friend struct cel::interop_internal::CelMapAccess; virtual cel::internal::TypeInfo TypeId() const { return cel::internal::TypeInfo(); diff --git a/eval/public/container_function_registrar_test.cc b/eval/public/container_function_registrar_test.cc index 675254974..0e782f45c 100644 --- a/eval/public/container_function_registrar_test.cc +++ b/eval/public/container_function_registrar_test.cc @@ -30,6 +30,8 @@ namespace google::api::expr::runtime { namespace { +using google::api::expr::v1alpha1::Expr; +using google::api::expr::v1alpha1::SourceInfo; using testing::ValuesIn; struct TestCase { diff --git a/eval/public/structs/cel_proto_lite_wrap_util_test.cc b/eval/public/structs/cel_proto_lite_wrap_util_test.cc index c6ade7124..08590cc48 100644 --- a/eval/public/structs/cel_proto_lite_wrap_util_test.cc +++ b/eval/public/structs/cel_proto_lite_wrap_util_test.cc @@ -47,9 +47,9 @@ namespace google::api::expr::runtime::internal { namespace { using testing::Eq; -using testing::EqualsProto; using testing::UnorderedPointwise; using cel::internal::StatusIs; +using testutil::EqualsProto; using google::protobuf::Duration; using google::protobuf::ListValue; @@ -109,7 +109,7 @@ class ProtobufDescriptorAnyPackingApis : public LegacyAnyPackingApis { absl::Status Pack(const google::protobuf::MessageLite* message, google::protobuf::Any& any_message) const override { const google::protobuf::Message* message_ptr = - down_cast(message); + cel::internal::down_cast(message); any_message.PackFrom(*message_ptr); return absl::OkStatus(); } @@ -164,7 +164,7 @@ class CelProtoWrapperTest : public ::testing::Test { EXPECT_OK(result); tested_message = *result; EXPECT_TRUE(tested_message != nullptr); - EXPECT_THAT(*tested_message, testutil::EqualsProto(message)); + EXPECT_THAT(*tested_message, EqualsProto(message)); // Test the same as above, but with allocated message. MessageType* created_message = Arena::CreateMessage(arena()); @@ -173,7 +173,7 @@ class CelProtoWrapperTest : public ::testing::Test { EXPECT_EQ(created_message, *result); created_message = *result; EXPECT_TRUE(created_message != nullptr); - EXPECT_THAT(*created_message, testutil::EqualsProto(message)); + EXPECT_THAT(*created_message, EqualsProto(message)); } template @@ -209,7 +209,7 @@ class CelProtoWrapperTest : public ::testing::Test { return; } EXPECT_TRUE(cel_value.IsMessage()); - EXPECT_THAT(cel_value.MessageOrDie(), testutil::EqualsProto(*result)); + EXPECT_THAT(cel_value.MessageOrDie(), EqualsProto(*result)); } std::unique_ptr ReflectedCopy( @@ -260,7 +260,7 @@ TEST_F(CelProtoWrapperTest, TestDuration) { Duration out; auto status = cel::internal::EncodeDuration(value.DurationOrDie(), &out); EXPECT_TRUE(status.ok()); - EXPECT_THAT(out, testutil::EqualsProto(msg_duration)); + EXPECT_THAT(out, EqualsProto(msg_duration)); } // This test verifies CelValue support of Timestamp type. @@ -275,7 +275,7 @@ TEST_F(CelProtoWrapperTest, TestTimestamp) { Timestamp out; auto status = cel::internal::EncodeTime(value.TimestampOrDie(), &out); EXPECT_TRUE(status.ok()); - EXPECT_THAT(out, testutil::EqualsProto(msg_timestamp)); + EXPECT_THAT(out, EqualsProto(msg_timestamp)); } // Dynamic Values test diff --git a/eval/tests/BUILD b/eval/tests/BUILD index beff303b5..cbe3b246a 100644 --- a/eval/tests/BUILD +++ b/eval/tests/BUILD @@ -115,6 +115,7 @@ cc_test( "//eval/public/testing:matchers", "//internal:testing", "//parser", + "//testutil:util", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", diff --git a/eval/tests/benchmark_test.cc b/eval/tests/benchmark_test.cc index 95b538d60..bd66af8aa 100644 --- a/eval/tests/benchmark_test.cc +++ b/eval/tests/benchmark_test.cc @@ -38,6 +38,7 @@ namespace runtime { namespace { using ::google::api::expr::v1alpha1::Expr; +using ::google::api::expr::v1alpha1::ParsedExpr; using ::google::api::expr::v1alpha1::SourceInfo; using ::google::rpc::context::AttributeContext; diff --git a/eval/tests/expression_builder_benchmark_test.cc b/eval/tests/expression_builder_benchmark_test.cc index 4317b0e49..3dcc383e8 100644 --- a/eval/tests/expression_builder_benchmark_test.cc +++ b/eval/tests/expression_builder_benchmark_test.cc @@ -37,6 +37,7 @@ namespace google::api::expr::runtime { namespace { +using google::api::expr::v1alpha1::CheckedExpr; using google::api::expr::v1alpha1::ParsedExpr; enum BenchmarkParam : int { diff --git a/eval/tests/memory_safety_test.cc b/eval/tests/memory_safety_test.cc index d1655f33f..becd93fd1 100644 --- a/eval/tests/memory_safety_test.cc +++ b/eval/tests/memory_safety_test.cc @@ -33,13 +33,15 @@ #include "eval/public/testing/matchers.h" #include "internal/testing.h" #include "parser/parser.h" +#include "testutil/util.h" namespace google::api::expr::runtime { namespace { +using ::google::api::expr::v1alpha1::ParsedExpr; using ::google::rpc::context::AttributeContext; -using testing::EqualsProto; using cel::internal::IsOkAndHolds; +using testutil::EqualsProto; struct TestCase { std::string name; diff --git a/extensions/protobuf/BUILD b/extensions/protobuf/BUILD index 785bef23c..7fe23ec5b 100644 --- a/extensions/protobuf/BUILD +++ b/extensions/protobuf/BUILD @@ -195,6 +195,7 @@ cc_test( "//extensions/protobuf/internal:descriptors", "//extensions/protobuf/internal:testing", "//internal:testing", + "//testutil:util", "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/log:die_if_null", "@com_google_absl//absl/status", diff --git a/extensions/protobuf/struct_value_test.cc b/extensions/protobuf/struct_value_test.cc index e45a877c8..3f0b28b7c 100644 --- a/extensions/protobuf/struct_value_test.cc +++ b/extensions/protobuf/struct_value_test.cc @@ -23,6 +23,7 @@ #include "google/protobuf/struct.pb.h" #include "google/protobuf/timestamp.pb.h" #include "google/protobuf/wrappers.pb.h" +#include "google/protobuf/descriptor.pb.h" #include "absl/functional/function_ref.h" #include "absl/log/die_if_null.h" #include "absl/status/status.h" @@ -38,6 +39,7 @@ #include "extensions/protobuf/type_provider.h" #include "extensions/protobuf/value.h" #include "internal/testing.h" +#include "testutil/util.h" #include "proto/test/v1/proto3/test_all_types.pb.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" @@ -49,11 +51,11 @@ namespace { using FieldId = ::cel::extensions::ProtoStructType::FieldId; using ::cel_testing::ValueOf; +using google::api::expr::testutil::EqualsProto; using testing::Eq; -using testing::EqualsProto; using testing::Optional; -using testing::status::CanonicalStatusIs; using cel::internal::IsOkAndHolds; +using cel::internal::StatusIs; using TestAllTypes = ::google::api::expr::test::v1::proto3::TestAllTypes; using NullValueProto = ::google::protobuf::NullValue; @@ -1868,11 +1870,11 @@ void TestMapGetField(MemoryManager& memory_manager, EXPECT_THAT(field.As()->Get( MapValue::GetContext(value_factory), value_factory.CreateErrorValue(absl::CancelledError())), - CanonicalStatusIs(absl::StatusCode::kInvalidArgument)); + StatusIs(absl::StatusCode::kInvalidArgument)); EXPECT_THAT(field.As()->Has( MapValue::HasContext(), value_factory.CreateErrorValue(absl::CancelledError())), - CanonicalStatusIs(absl::StatusCode::kInvalidArgument)); + StatusIs(absl::StatusCode::kInvalidArgument)); ASSERT_OK_AND_ASSIGN( auto keys, field.As()->ListKeys(MapValue::ListKeysContext(value_factory))); @@ -1960,11 +1962,11 @@ void TestStringMapGetField(MemoryManager& memory_manager, EXPECT_THAT(field.As()->Get( MapValue::GetContext(value_factory), value_factory.CreateErrorValue(absl::CancelledError())), - CanonicalStatusIs(absl::StatusCode::kInvalidArgument)); + StatusIs(absl::StatusCode::kInvalidArgument)); EXPECT_THAT(field.As()->Has( MapValue::HasContext(), value_factory.CreateErrorValue(absl::CancelledError())), - CanonicalStatusIs(absl::StatusCode::kInvalidArgument)); + StatusIs(absl::StatusCode::kInvalidArgument)); ASSERT_OK_AND_ASSIGN( auto keys, field.As()->ListKeys(MapValue::ListKeysContext(value_factory))); @@ -4123,7 +4125,7 @@ TEST_P(ProtoStructValueTest, NewFieldIteratorIds) { actual_ids.insert(id); } EXPECT_THAT(iterator->NextId(StructValue::GetFieldContext(value_factory)), - CanonicalStatusIs(absl::StatusCode::kFailedPrecondition)); + StatusIs(absl::StatusCode::kFailedPrecondition)); std::set expected_ids = { FieldIdFactory::Make(13), FieldIdFactory::Make(1), FieldIdFactory::Make(2), FieldIdFactory::Make(3), @@ -4169,7 +4171,7 @@ TEST_P(ProtoStructValueTest, NewFieldIteratorValues) { actual_values.push_back(std::move(value)); } EXPECT_THAT(iterator->NextValue(StructValue::GetFieldContext(value_factory)), - CanonicalStatusIs(absl::StatusCode::kFailedPrecondition)); + StatusIs(absl::StatusCode::kFailedPrecondition)); // We cannot really test actual_types, as hand translating TestAllTypes would // be obnoxious. Otherwise we would simply be testing the same logic against // itself, which would not be useful. diff --git a/extensions/protobuf/type_test.cc b/extensions/protobuf/type_test.cc index 80338e839..e6b033a0d 100644 --- a/extensions/protobuf/type_test.cc +++ b/extensions/protobuf/type_test.cc @@ -28,8 +28,8 @@ namespace cel::extensions { namespace { using ::cel_testing::TypeIs; -using testing::status::CanonicalStatusIs; using cel::internal::IsOkAndHolds; +using cel::internal::StatusIs; using ProtoTypeTest = ProtoTest<>; @@ -95,7 +95,7 @@ TEST_P(ProtoTypeTest, ResolveNotFound) { TypeManager type_manager(type_factory, TypeProvider::Builtin()); EXPECT_THAT( ProtoType::Resolve(type_manager, *google::protobuf::Api::descriptor()), - CanonicalStatusIs(absl::StatusCode::kNotFound)); + StatusIs(absl::StatusCode::kNotFound)); } INSTANTIATE_TEST_SUITE_P(ProtoTypeTest, ProtoTypeTest, diff --git a/extensions/protobuf/value_test.cc b/extensions/protobuf/value_test.cc index 6909b0c1c..90e21bb0c 100644 --- a/extensions/protobuf/value_test.cc +++ b/extensions/protobuf/value_test.cc @@ -31,6 +31,7 @@ #include "extensions/protobuf/struct_value.h" #include "extensions/protobuf/type_provider.h" #include "internal/testing.h" +#include "testutil/util.h" #include "proto/test/v1/proto3/test_all_types.pb.h" #include "google/protobuf/generated_enum_reflection.h" @@ -38,8 +39,8 @@ namespace cel::extensions { namespace { using ::cel_testing::ValueOf; +using google::api::expr::testutil::EqualsProto; using testing::Eq; -using testing::EqualsProto; using testing::Optional; using cel::internal::IsOkAndHolds; diff --git a/testutil/util.h b/testutil/util.h index 170c140b8..2871a0a46 100644 --- a/testutil/util.h +++ b/testutil/util.h @@ -37,16 +37,27 @@ class ProtoStringMatcher { : expected_(expected) {} explicit inline ProtoStringMatcher(const google::protobuf::Message& expected) - : expected_(expected.DebugString()) {} + : expected_(expected.DebugString()), + expected_bytes_(expected.SerializeAsString()) {} template bool MatchAndExplain(const Message& p, ::testing::MatchResultListener* /* listener */) const; + bool MatchAndExplain(const google::protobuf::Message& p, + ::testing::MatchResultListener* /* listener */) const { + return p.SerializeAsString() == expected_bytes_; + } + template bool MatchAndExplain(const Message* p, ::testing::MatchResultListener* /* listener */) const; + bool MatchAndExplain(const google::protobuf::MessageLite* p, + ::testing::MatchResultListener* /* listener */) const { + return p->SerializeAsString() == expected_bytes_; + } + inline void DescribeTo(::std::ostream* os) const { *os << expected_; } inline void DescribeNegationTo(::std::ostream* os) const { *os << "not equal to expected message: " << expected_; @@ -54,6 +65,7 @@ class ProtoStringMatcher { private: const std::string expected_; + const std::string expected_bytes_; }; // Polymorphic matcher to compare any two protos. From 680f6d8d59709e2f260c546514546ba3b059759b Mon Sep 17 00:00:00 2001 From: jcking Date: Tue, 16 May 2023 21:36:03 +0000 Subject: [PATCH 284/303] Miscellaneous cleanup PiperOrigin-RevId: 532576668 --- extensions/protobuf/struct_type.h | 2 +- extensions/protobuf/struct_value.h | 2 +- extensions/protobuf/value.cc | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/extensions/protobuf/struct_type.h b/extensions/protobuf/struct_type.h index a0257a5dd..bbfed1810 100644 --- a/extensions/protobuf/struct_type.h +++ b/extensions/protobuf/struct_type.h @@ -43,7 +43,7 @@ class ProtoStructTypeFieldIterator; class ProtoStructType final : public CEL_STRUCT_TYPE_CLASS { public: static bool Is(const Type& type) { - return type.kind() == Kind::kStruct && + return CEL_STRUCT_TYPE_CLASS::Is(type) && cel::base_internal::GetStructTypeTypeId( static_cast(type)) == cel::internal::TypeId(); diff --git a/extensions/protobuf/struct_value.h b/extensions/protobuf/struct_value.h index 1d748dd64..72b2da396 100644 --- a/extensions/protobuf/struct_value.h +++ b/extensions/protobuf/struct_value.h @@ -64,7 +64,7 @@ class ProtoStructValue : public CEL_STRUCT_VALUE_CLASS { public: static bool Is(const Value& value) { - return value.kind() == Kind::kStruct && + return CEL_STRUCT_VALUE_CLASS::Is(value) && cel::base_internal::GetStructValueTypeId( static_cast(value)) == cel::internal::TypeId(); diff --git a/extensions/protobuf/value.cc b/extensions/protobuf/value.cc index f28f871ee..daf2c90f3 100644 --- a/extensions/protobuf/value.cc +++ b/extensions/protobuf/value.cc @@ -529,7 +529,7 @@ absl::StatusOr> ProtoValue::Create( absl::WrapUnique(value->release_struct_value())); default: return absl::InvalidArgumentError(absl::StrCat( - "unexpected google.protobuf.Value kind: %d", value->kind_case())); + "unexpected google.protobuf.Value kind: ", value->kind_case())); } } @@ -552,7 +552,7 @@ absl::StatusOr> ProtoValue::Create( return Create(value_factory, std::move(*value.mutable_struct_value())); default: return absl::InvalidArgumentError(absl::StrCat( - "unexpected google.protobuf.Value kind: %d", value.kind_case())); + "unexpected google.protobuf.Value kind: ", value.kind_case())); } } @@ -579,7 +579,7 @@ absl::StatusOr> ProtoValue::CreateBorrowed( value.struct_value()); default: return absl::InvalidArgumentError(absl::StrCat( - "unexpected google.protobuf.Value kind: %d", value.kind_case())); + "unexpected google.protobuf.Value kind: ", value.kind_case())); } } From fd5f75dffffa7909597004c8b0285cfc683267a2 Mon Sep 17 00:00:00 2001 From: jcking Date: Wed, 17 May 2023 17:26:27 +0000 Subject: [PATCH 285/303] Simplify `ListValueBuilder` and `MapValueBuilder` PiperOrigin-RevId: 532830941 --- base/values/list_value_builder.h | 40 +- base/values/map_value_builder.h | 810 ++++--------------------------- 2 files changed, 99 insertions(+), 751 deletions(-) diff --git a/base/values/list_value_builder.h b/base/values/list_value_builder.h index b655d41a7..ae4d9d7d6 100644 --- a/base/values/list_value_builder.h +++ b/base/values/list_value_builder.h @@ -31,7 +31,7 @@ namespace cel { // Abstract interface for building ListValue. // -// ListValueBuilderInterface is not re-usable, once Build() is called the state +// ListValueBuilderInterface is not reusable, once Build() is called the state // of ListValueBuilderInterface is undefined. class ListValueBuilderInterface { public: @@ -39,9 +39,7 @@ class ListValueBuilderInterface { virtual std::string DebugString() const = 0; - virtual absl::Status Add(const Handle& value) = 0; - - virtual absl::Status Add(Handle&& value) = 0; + virtual absl::Status Add(Handle value) = 0; virtual size_t size() const = 0; @@ -112,20 +110,11 @@ class ListValueBuilderImpl : public ListValueBuilderInterface { return out; } - absl::Status Add(const Handle& value) override { - return Add(value.As()); - } - - absl::Status Add(Handle&& value) override { - return Add(value.As()); + absl::Status Add(Handle value) override { + return Add(std::move(value).As()); } - absl::Status Add(const Handle& value) { - storage_.push_back(value); - return absl::OkStatus(); - } - - absl::Status Add(Handle&& value) { + absl::Status Add(Handle value) { storage_.push_back(std::move(value)); return absl::OkStatus(); } @@ -176,12 +165,7 @@ class ListValueBuilderImpl : public ListValueBuilderInterface { return out; } - absl::Status Add(const Handle& value) override { - storage_.push_back(value); - return absl::OkStatus(); - } - - absl::Status Add(Handle&& value) override { + absl::Status Add(Handle value) override { storage_.push_back(std::move(value)); return absl::OkStatus(); } @@ -233,20 +217,14 @@ class ListValueBuilderImpl : public ListValueBuilderInterface { return out; } - absl::Status Add(const Handle& value) override { - return Add(value.As()); - } - - absl::Status Add(Handle&& value) override { - return Add(value.As()); + absl::Status Add(Handle value) override { + return Add(std::move(value).As()); } absl::Status Add(const Handle& value) { return Add(value->value()); } - absl::Status Add(Handle&& value) { return Add(value->value()); } - absl::Status Add(U value) { - storage_.push_back(value); + storage_.push_back(std::move(value)); return absl::OkStatus(); } diff --git a/base/values/map_value_builder.h b/base/values/map_value_builder.h index 3521e77ca..931ad153d 100644 --- a/base/values/map_value_builder.h +++ b/base/values/map_value_builder.h @@ -15,12 +15,9 @@ #ifndef THIRD_PARTY_CEL_CPP_BASE_VALUES_MAP_VALUE_BUILDER_H_ #define THIRD_PARTY_CEL_CPP_BASE_VALUES_MAP_VALUE_BUILDER_H_ -#include -#include #include #include #include -#include #include "absl/base/attributes.h" #include "absl/base/macros.h" @@ -36,7 +33,7 @@ namespace cel { // Abstract interface for building MapValue. // -// MapValueBuilderInterface is not re-usable, once Build() is called the state +// MapValueBuilderInterface is not reusable, once Build() is called the state // of MapValueBuilderInterface is undefined. class MapValueBuilderInterface { public: @@ -46,33 +43,19 @@ class MapValueBuilderInterface { // Insert a new entry. Returns true if the key did not already exist and the // insertion was performed, false otherwise. - virtual absl::StatusOr Insert(const Handle& key, - const Handle& value) = 0; - virtual absl::StatusOr Insert(const Handle& key, - Handle&& value) = 0; - virtual absl::StatusOr Insert(Handle&& key, - const Handle& value) = 0; - virtual absl::StatusOr Insert(Handle&& key, - Handle&& value) = 0; + virtual absl::StatusOr Insert(Handle key, + Handle value) = 0; // Update an already existing entry. Returns true if the key already existed // and the update was performed, false otherwise. virtual absl::StatusOr Update(const Handle& key, - const Handle& value) = 0; - virtual absl::StatusOr Update(const Handle& key, - Handle&& value) = 0; + Handle value) = 0; // A combination of Insert and Update, where the entry is inserted if it // doesn't already exist or it is updated. Returns true if insertion occurred, // false otherwise. - virtual absl::StatusOr InsertOrUpdate(const Handle& key, - const Handle& value) = 0; - virtual absl::StatusOr InsertOrUpdate(const Handle& key, - Handle&& value) = 0; - virtual absl::StatusOr InsertOrUpdate(Handle&& key, - const Handle& value) = 0; - virtual absl::StatusOr InsertOrUpdate(Handle&& key, - Handle&& value) = 0; + virtual absl::StatusOr InsertOrUpdate(Handle key, + Handle value) = 0; // Returns whether the given key has been inserted. virtual bool Has(const Handle& key) const = 0; @@ -177,7 +160,7 @@ struct MapKeyEqualer> { }; // For MapValueBuilder we use a linked hash map to preserve insertion order. -// This mimics protobuf and ensures some reproducability, making testing easier. +// This mimics protobuf and ensures some reproducibility, making testing easier. // Implementation used by MapValueBuilder when both the key and value are // represented as Value and not some C++ primitive. @@ -517,63 +500,21 @@ class MapValueBuilderImpl : public MapValueBuilderInterface { return out; } - absl::StatusOr Insert(const Handle& key, - const Handle& value) override { - return Insert(key.As(), value.As()); - } - - absl::StatusOr Insert(const Handle& key, - Handle&& value) override { - return Insert(key.As(), value.As()); - } - - absl::StatusOr Insert(Handle&& key, - const Handle& value) override { - return Insert(key.As(), value.As()); - } - - absl::StatusOr Insert(Handle&& key, - Handle&& value) override { - return Insert(key.As(), value.As()); - } - - absl::StatusOr Insert(const Handle& key, const Handle& value) { - return storage_.insert(std::make_pair(key, value)).second; - } - - absl::StatusOr Insert(const Handle& key, Handle&& value) { - return storage_.insert(std::make_pair(key, std::move(value))).second; - } - - absl::StatusOr Insert(Handle&& key, const Handle& value) { - return storage_.insert(std::make_pair(std::move(key), value)).second; + absl::StatusOr Insert(Handle key, Handle value) override { + return Insert(std::move(key).As(), std::move(value).As()); } - absl::StatusOr Insert(Handle&& key, Handle&& value) { + absl::StatusOr Insert(Handle key, Handle value) { return storage_.insert(std::make_pair(std::move(key), std::move(value))) .second; } absl::StatusOr Update(const Handle& key, - const Handle& value) override { - return Update(key.As(), value.As()); - } - - absl::StatusOr Update(const Handle& key, - Handle&& value) override { - return Update(key.As(), value.As()); - } - - absl::StatusOr Update(const Handle& key, const Handle& value) { - auto existing = storage_.find(key); - if (existing == storage_.end()) { - return false; - } - existing->second = value; - return true; + Handle value) override { + return Update(key.As(), std::move(value).As()); } - absl::StatusOr Update(const Handle& key, Handle&& value) { + absl::StatusOr Update(const Handle& key, Handle value) { auto existing = storage_.find(key); if (existing == storage_.end()) { return false; @@ -582,40 +523,12 @@ class MapValueBuilderImpl : public MapValueBuilderInterface { return true; } - absl::StatusOr InsertOrUpdate(const Handle& key, - const Handle& value) override { - return InsertOrUpdate(key.As(), value.As()); - } - - absl::StatusOr InsertOrUpdate(const Handle& key, - Handle&& value) override { - return InsertOrUpdate(key.As(), value.As()); - } - - absl::StatusOr InsertOrUpdate(Handle&& key, - const Handle& value) override { - return InsertOrUpdate(key.As(), value.As()); - } - - absl::StatusOr InsertOrUpdate(Handle&& key, - Handle&& value) override { - return InsertOrUpdate(key.As(), value.As()); - } - - absl::StatusOr InsertOrUpdate(const Handle& key, - const Handle& value) { - return storage_.insert_or_assign(key, value).second; - } - - absl::StatusOr InsertOrUpdate(const Handle& key, Handle&& value) { - return storage_.insert_or_assign(key, std::move(value)).second; - } - - absl::StatusOr InsertOrUpdate(Handle&& key, const Handle& value) { - return storage_.insert_or_assign(std::move(key), value).second; + absl::StatusOr InsertOrUpdate(Handle key, + Handle value) override { + return InsertOrUpdate(std::move(key).As(), std::move(value).As()); } - absl::StatusOr InsertOrUpdate(Handle&& key, Handle&& value) { + absl::StatusOr InsertOrUpdate(Handle key, Handle value) { return storage_.insert_or_assign(std::move(key), std::move(value)).second; } @@ -683,65 +596,21 @@ class MapValueBuilderImpl return out; } - absl::StatusOr Insert(const Handle& key, - const Handle& value) override { - return Insert(key.As(), value); - } - - absl::StatusOr Insert(const Handle& key, - Handle&& value) override { - return Insert(key.As(), std::move(value)); - } - - absl::StatusOr Insert(Handle&& key, - const Handle& value) override { - return Insert(key.As(), value); - } - - absl::StatusOr Insert(Handle&& key, - Handle&& value) override { - return Insert(key.As(), std::move(value)); - } - - absl::StatusOr Insert(const Handle& key, - const Handle& value) { - return storage_.insert(std::make_pair(key, value)).second; - } - - absl::StatusOr Insert(const Handle& key, Handle&& value) { - return storage_.insert(std::make_pair(key, std::move(value))).second; + absl::StatusOr Insert(Handle key, Handle value) override { + return Insert(std::move(key).As(), std::move(value)); } - absl::StatusOr Insert(Handle&& key, const Handle& value) { - return storage_.insert(std::make_pair(std::move(key), value)).second; - } - - absl::StatusOr Insert(Handle&& key, Handle&& value) { + absl::StatusOr Insert(Handle key, Handle value) { return storage_.insert(std::make_pair(std::move(key), std::move(value))) .second; } absl::StatusOr Update(const Handle& key, - const Handle& value) override { - return Update(key.As(), value); - } - - absl::StatusOr Update(const Handle& key, - Handle&& value) override { + Handle value) override { return Update(key.As(), std::move(value)); } - absl::StatusOr Update(const Handle& key, - const Handle& value) { - auto existing = storage_.find(key); - if (existing == storage_.end()) { - return false; - } - existing->second = value; - return true; - } - - absl::StatusOr Update(const Handle& key, Handle&& value) { + absl::StatusOr Update(const Handle& key, Handle value) { auto existing = storage_.find(key); if (existing == storage_.end()) { return false; @@ -750,42 +619,12 @@ class MapValueBuilderImpl return true; } - absl::StatusOr InsertOrUpdate(const Handle& key, - const Handle& value) override { - return InsertOrUpdate(key.As(), value); - } - - absl::StatusOr InsertOrUpdate(const Handle& key, - Handle&& value) override { - return InsertOrUpdate(key.As(), std::move(value)); - } - - absl::StatusOr InsertOrUpdate(Handle&& key, - const Handle& value) override { - return InsertOrUpdate(key.As(), value); - } - - absl::StatusOr InsertOrUpdate(Handle&& key, - Handle&& value) override { - return InsertOrUpdate(key.As(), std::move(value)); - } - - absl::StatusOr InsertOrUpdate(const Handle& key, - const Handle& value) { - return storage_.insert_or_assign(key, value).second; + absl::StatusOr InsertOrUpdate(Handle key, + Handle value) override { + return InsertOrUpdate(std::move(key).As(), std::move(value)); } - absl::StatusOr InsertOrUpdate(const Handle& key, - Handle&& value) { - return storage_.insert_or_assign(key, std::move(value)).second; - } - - absl::StatusOr InsertOrUpdate(Handle&& key, - const Handle& value) { - return storage_.insert_or_assign(std::move(key), value).second; - } - - absl::StatusOr InsertOrUpdate(Handle&& key, Handle&& value) { + absl::StatusOr InsertOrUpdate(Handle key, Handle value) { return storage_.insert_or_assign(std::move(key), std::move(value)).second; } @@ -853,65 +692,21 @@ class MapValueBuilderImpl return out; } - absl::StatusOr Insert(const Handle& key, - const Handle& value) override { - return Insert(key, value.As()); - } - - absl::StatusOr Insert(const Handle& key, - Handle&& value) override { - return Insert(key, value.As()); - } - - absl::StatusOr Insert(Handle&& key, - const Handle& value) override { - return Insert(std::move(key), value.As()); - } - - absl::StatusOr Insert(Handle&& key, - Handle&& value) override { - return Insert(std::move(key), value.As()); - } - - absl::StatusOr Insert(const Handle& key, - const Handle& value) { - return storage_.insert(std::make_pair(key, value)).second; - } - - absl::StatusOr Insert(const Handle& key, Handle&& value) { - return storage_.insert(std::make_pair(key, std::move(value))).second; - } - - absl::StatusOr Insert(Handle&& key, const Handle& value) { - return storage_.insert(std::make_pair(std::move(key), value)).second; + absl::StatusOr Insert(Handle key, Handle value) override { + return Insert(std::move(key), std::move(value).As()); } - absl::StatusOr Insert(Handle&& key, Handle&& value) { + absl::StatusOr Insert(Handle key, Handle value) { return storage_.insert(std::make_pair(std::move(key), std::move(value))) .second; } absl::StatusOr Update(const Handle& key, - const Handle& value) override { - return Update(key, value.As()); - } - - absl::StatusOr Update(const Handle& key, - Handle&& value) override { - return Update(key, value.As()); - } - - absl::StatusOr Update(const Handle& key, - const Handle& value) { - auto existing = storage_.find(key); - if (existing == storage_.end()) { - return false; - } - existing->second = value; - return true; + Handle value) override { + return Update(key, std::move(value).As()); } - absl::StatusOr Update(const Handle& key, Handle&& value) { + absl::StatusOr Update(const Handle& key, Handle value) { auto existing = storage_.find(key); if (existing == storage_.end()) { return false; @@ -920,42 +715,12 @@ class MapValueBuilderImpl return true; } - absl::StatusOr InsertOrUpdate(const Handle& key, - const Handle& value) override { - return InsertOrUpdate(key, value.As()); - } - - absl::StatusOr InsertOrUpdate(const Handle& key, - Handle&& value) override { - return InsertOrUpdate(key, value.As()); - } - - absl::StatusOr InsertOrUpdate(Handle&& key, - const Handle& value) override { - return InsertOrUpdate(std::move(key), value.As()); - } - - absl::StatusOr InsertOrUpdate(Handle&& key, - Handle&& value) override { - return InsertOrUpdate(std::move(key), value.As()); - } - - absl::StatusOr InsertOrUpdate(const Handle& key, - const Handle& value) { - return storage_.insert_or_assign(key, value).second; + absl::StatusOr InsertOrUpdate(Handle key, + Handle value) override { + return InsertOrUpdate(std::move(key), std::move(value).As()); } - absl::StatusOr InsertOrUpdate(const Handle& key, - Handle&& value) { - return storage_.insert_or_assign(key, std::move(value)).second; - } - - absl::StatusOr InsertOrUpdate(Handle&& key, - const Handle& value) { - return storage_.insert_or_assign(std::move(key), value).second; - } - - absl::StatusOr InsertOrUpdate(Handle&& key, Handle&& value) { + absl::StatusOr InsertOrUpdate(Handle key, Handle value) { return storage_.insert_or_assign(std::move(key), std::move(value)).second; } @@ -1017,39 +782,13 @@ class MapValueBuilderImpl return out; } - absl::StatusOr Insert(const Handle& key, - const Handle& value) override { - return storage_.insert(std::make_pair(key, value)).second; - } - - absl::StatusOr Insert(const Handle& key, - Handle&& value) override { - return storage_.insert(std::make_pair(key, std::move(value))).second; - } - - absl::StatusOr Insert(Handle&& key, - const Handle& value) override { - return storage_.insert(std::make_pair(std::move(key), value)).second; - } - - absl::StatusOr Insert(Handle&& key, - Handle&& value) override { + absl::StatusOr Insert(Handle key, Handle value) override { return storage_.insert(std::make_pair(std::move(key), std::move(value))) .second; } absl::StatusOr Update(const Handle& key, - const Handle& value) override { - auto existing = storage_.find(key); - if (existing == storage_.end()) { - return false; - } - existing->second = value; - return true; - } - - absl::StatusOr Update(const Handle& key, - Handle&& value) override { + Handle value) override { auto existing = storage_.find(key); if (existing == storage_.end()) { return false; @@ -1058,23 +797,8 @@ class MapValueBuilderImpl return true; } - absl::StatusOr InsertOrUpdate(const Handle& key, - const Handle& value) override { - return storage_.insert_or_assign(key, value).second; - } - - absl::StatusOr InsertOrUpdate(const Handle& key, - Handle&& value) override { - return storage_.insert_or_assign(key, std::move(value)).second; - } - - absl::StatusOr InsertOrUpdate(Handle&& key, - const Handle& value) override { - return storage_.insert_or_assign(std::move(key), value).second; - } - - absl::StatusOr InsertOrUpdate(Handle&& key, - Handle&& value) override { + absl::StatusOr InsertOrUpdate(Handle key, + Handle value) override { return storage_.insert_or_assign(std::move(key), std::move(value)).second; } @@ -1141,87 +865,29 @@ class MapValueBuilderImpl : public MapValueBuilderInterface { return out; } - absl::StatusOr Insert(const Handle& key, - const Handle& value) override { - return Insert(key.As(), value.As()); - } - - absl::StatusOr Insert(const Handle& key, - Handle&& value) override { - return Insert(key.As(), value.As()); - } - - absl::StatusOr Insert(Handle&& key, - const Handle& value) override { - return Insert(key.As(), value.As()); - } - - absl::StatusOr Insert(Handle&& key, - Handle&& value) override { - return Insert(key.As(), value.As()); + absl::StatusOr Insert(Handle key, Handle value) override { + return Insert(std::move(key).As(), std::move(value).As()); } - absl::StatusOr Insert(const Handle& key, const Handle& value) { - return Insert(key->value(), value); - } - - absl::StatusOr Insert(const Handle& key, Handle&& value) { - return Insert(key->value(), std::move(value)); - } - - absl::StatusOr Insert(Handle&& key, const Handle& value) { - return Insert(key->value(), value); - } - - absl::StatusOr Insert(Handle&& key, Handle&& value) { + absl::StatusOr Insert(Handle key, Handle value) { return Insert(key->value(), std::move(value)); } - absl::StatusOr Insert(const K& key, const Handle& value) { - return storage_.insert(std::make_pair(key, value)).second; - } - - absl::StatusOr Insert(const K& key, Handle&& value) { - return storage_.insert(std::make_pair(key, std::move(value))).second; - } - - absl::StatusOr Insert(K&& key, const Handle& value) { - return storage_.insert(std::make_pair(std::move(key), value)).second; - } - - absl::StatusOr Insert(K&& key, Handle&& value) { + absl::StatusOr Insert(UK key, Handle value) { return storage_.insert(std::make_pair(std::move(key), std::move(value))) .second; } absl::StatusOr Update(const Handle& key, - const Handle& value) override { - return Update(key.As(), value.As()); - } - - absl::StatusOr Update(const Handle& key, - Handle&& value) override { - return Update(key.As(), value.As()); - } - - absl::StatusOr Update(const Handle& key, const Handle& value) { - return Update(key, value); - } - - absl::StatusOr Update(const Handle& key, Handle&& value) { - return Update(std::move(key), std::move(value)); + Handle value) override { + return Update(key.As(), std::move(value).As()); } - absl::StatusOr Update(const K& key, const Handle& value) { - auto existing = storage_.find(key); - if (existing == storage_.end()) { - return false; - } - existing->second = value; - return true; + absl::StatusOr Update(const Handle& key, Handle value) { + return Update(key->value(), std::move(value)); } - absl::StatusOr Update(const K& key, Handle&& value) { + absl::StatusOr Update(const UK& key, Handle value) { auto existing = storage_.find(key); if (existing == storage_.end()) { return false; @@ -1230,56 +896,16 @@ class MapValueBuilderImpl : public MapValueBuilderInterface { return true; } - absl::StatusOr InsertOrUpdate(const Handle& key, - const Handle& value) override { - return InsertOrUpdate(key.As(), value.As()); - } - - absl::StatusOr InsertOrUpdate(const Handle& key, - Handle&& value) override { - return InsertOrUpdate(key.As(), value.As()); - } - - absl::StatusOr InsertOrUpdate(Handle&& key, - const Handle& value) override { - return InsertOrUpdate(key.As(), value.As()); + absl::StatusOr InsertOrUpdate(Handle key, + Handle value) override { + return InsertOrUpdate(std::move(key).As(), std::move(value).As()); } - absl::StatusOr InsertOrUpdate(Handle&& key, - Handle&& value) override { - return InsertOrUpdate(key.As(), value.As()); - } - - absl::StatusOr InsertOrUpdate(const Handle& key, - const Handle& value) { - return InsertOrUpdate(key->value(), value); - } - - absl::StatusOr InsertOrUpdate(const Handle& key, Handle&& value) { - return InsertOrUpdate(key->value(), std::move(value)); - } - - absl::StatusOr InsertOrUpdate(Handle&& key, const Handle& value) { - return InsertOrUpdate(key->value(), value); - } - - absl::StatusOr InsertOrUpdate(Handle&& key, Handle&& value) { + absl::StatusOr InsertOrUpdate(Handle key, Handle value) { return InsertOrUpdate(key->value(), std::move(value)); } - absl::StatusOr InsertOrUpdate(const K& key, const Handle& value) { - return storage_.insert_or_assign(key, value).second; - } - - absl::StatusOr InsertOrUpdate(const K& key, Handle&& value) { - return storage_.insert_or_assign(key, std::move(value)).second; - } - - absl::StatusOr InsertOrUpdate(K&& key, const Handle& value) { - return storage_.insert_or_assign(std::move(key), value).second; - } - - absl::StatusOr InsertOrUpdate(K&& key, Handle&& value) { + absl::StatusOr InsertOrUpdate(UK key, Handle value) { return storage_.insert_or_assign(std::move(key), std::move(value)).second; } @@ -1347,87 +973,29 @@ class MapValueBuilderImpl : public MapValueBuilderInterface { return out; } - absl::StatusOr Insert(const Handle& key, - const Handle& value) override { - return Insert(key.As(), value.As()); - } - - absl::StatusOr Insert(const Handle& key, - Handle&& value) override { - return Insert(key.As(), value.As()); - } - - absl::StatusOr Insert(Handle&& key, - const Handle& value) override { - return Insert(key.As(), value.As()); - } - - absl::StatusOr Insert(Handle&& key, - Handle&& value) override { - return Insert(key.As(), value.As()); - } - - absl::StatusOr Insert(const Handle& key, const Handle& value) { - return Insert(key, value->value()); - } - - absl::StatusOr Insert(const Handle& key, Handle&& value) { - return Insert(key, value->value()); - } - - absl::StatusOr Insert(Handle&& key, const Handle& value) { - return Insert(key, value->value()); - } - - absl::StatusOr Insert(Handle&& key, Handle&& value) { - return Insert(key, value->value()); - } - - absl::StatusOr Insert(const Handle& key, const UV& value) { - return storage_.insert(std::make_pair(key, value)).second; - } - - absl::StatusOr Insert(const Handle& key, UV&& value) { - return storage_.insert(std::make_pair(key, std::move(value))).second; + absl::StatusOr Insert(Handle key, Handle value) override { + return Insert(std::move(key).As(), std::move(value).As()); } - absl::StatusOr Insert(Handle&& key, const UV& value) { - return storage_.insert(std::make_pair(std::move(key), value)).second; + absl::StatusOr Insert(Handle key, Handle value) { + return Insert(std::move(key), value->value()); } - absl::StatusOr Insert(Handle&& key, UV&& value) { + absl::StatusOr Insert(Handle key, UV value) { return storage_.insert(std::make_pair(std::move(key), std::move(value))) .second; } absl::StatusOr Update(const Handle& key, - const Handle& value) override { - return Update(key.As(), value.As()); - } - - absl::StatusOr Update(const Handle& key, - Handle&& value) override { - return Update(key.As(), value.As()); - } - - absl::StatusOr Update(const Handle& key, const Handle& value) { - return Update(key, value->value()); + Handle value) override { + return Update(std::move(key).As(), std::move(value).As()); } - absl::StatusOr Update(const Handle& key, Handle&& value) { + absl::StatusOr Update(const Handle& key, Handle value) { return Update(key, value->value()); } - absl::StatusOr Update(const Handle& key, const UV& value) { - auto existing = storage_.find(key); - if (existing == storage_.end()) { - return false; - } - existing->second = value; - return true; - } - - absl::StatusOr Update(const Handle& key, UV&& value) { + absl::StatusOr Update(const Handle& key, UV value) { auto existing = storage_.find(key); if (existing == storage_.end()) { return false; @@ -1436,61 +1004,17 @@ class MapValueBuilderImpl : public MapValueBuilderInterface { return true; } - absl::StatusOr InsertOrUpdate(const Handle& key, - const Handle& value) override { - return InsertOrUpdate(key.As(), value.As()); - } - - absl::StatusOr InsertOrUpdate(const Handle& key, - Handle&& value) override { - return InsertOrUpdate(key.As(), value.As()); - } - - absl::StatusOr InsertOrUpdate(Handle&& key, - const Handle& value) override { - return InsertOrUpdate(key.As(), value.As()); - } - - absl::StatusOr InsertOrUpdate(Handle&& key, - Handle&& value) override { - return InsertOrUpdate(key.As(), value.As()); - } - - absl::StatusOr InsertOrUpdate(const Handle& key, - const Handle& value) { - return InsertOrUpdate(key, value->value()); - } - - absl::StatusOr InsertOrUpdate(const Handle& key, Handle&& value) { - return InsertOrUpdate(key, value->value()); - } - - absl::StatusOr InsertOrUpdate(Handle&& key, const Handle& value) { - return InsertOrUpdate(key, value->value()); - } - - absl::StatusOr InsertOrUpdate(Handle&& key, Handle&& value) { - return InsertOrUpdate(key, value->value()); - } - - absl::StatusOr InsertOrUpdate(const Handle& key, const UV& value) { - return storage_.insert_or_assign(std::make_pair(key, value)).second; - } - - absl::StatusOr InsertOrUpdate(const Handle& key, UV&& value) { - return storage_.insert_or_assign(std::make_pair(key, std::move(value))) - .second; + absl::StatusOr InsertOrUpdate(Handle key, + Handle value) override { + return InsertOrUpdate(std::move(key).As(), std::move(value).As()); } - absl::StatusOr InsertOrUpdate(Handle&& key, const UV& value) { - return storage_.insert_or_assign(std::make_pair(std::move(key), value)) - .second; + absl::StatusOr InsertOrUpdate(Handle key, Handle value) { + return InsertOrUpdate(std::move(key), value->value()); } - absl::StatusOr InsertOrUpdate(Handle&& key, UV&& value) { - return storage_ - .insert_or_assign(std::make_pair(std::move(key), std::move(value))) - .second; + absl::StatusOr InsertOrUpdate(Handle key, UV value) { + return storage_.insert_or_assign(std::move(key), std::move(value)).second; } bool Has(const Handle& key) const override { return Has(key.As()); } @@ -1558,135 +1082,45 @@ class MapValueBuilderImpl : public MapValueBuilderInterface { return out; } - absl::StatusOr Insert(const Handle& key, - const Handle& value) override { - return Insert(key.As(), value.As()); + absl::StatusOr Insert(Handle key, Handle value) override { + return Insert(std::move(key).As(), std::move(value).As()); } - absl::StatusOr Insert(const Handle& key, - Handle&& value) override { - return Insert(key.As(), value.As()); - } - - absl::StatusOr Insert(Handle&& key, - const Handle& value) override { - return Insert(key.As(), value.As()); - } - - absl::StatusOr Insert(Handle&& key, - Handle&& value) override { - return Insert(key.As(), value.As()); - } - - absl::StatusOr Insert(const Handle& key, const Handle& value) { + absl::StatusOr Insert(Handle key, Handle value) { return Insert(key->value(), value->value()); } - absl::StatusOr Insert(const Handle& key, Handle&& value) { - return Insert(key->value(), value->value()); - } - - absl::StatusOr Insert(Handle&& key, const Handle& value) { - return Insert(key->value(), value->value()); - } - - absl::StatusOr Insert(Handle&& key, Handle&& value) { - return Insert(key->value(), value->value()); - } - - absl::StatusOr Insert(const Handle& key, const V& value) { - return Insert(key->value(), value); - } - - absl::StatusOr Insert(const Handle& key, V&& value) { - return Insert(key->value(), std::move(value)); - } - - absl::StatusOr Insert(Handle&& key, const V& value) { - return Insert(key->value(), value); - } - - absl::StatusOr Insert(Handle&& key, V&& value) { + absl::StatusOr Insert(Handle key, UV value) { return Insert(key->value(), std::move(value)); } - absl::StatusOr Insert(const K& key, const Handle& value) { - return Insert(key, value->value()); - } - - absl::StatusOr Insert(const K& key, Handle&& value) { - return Insert(key, value->value()); - } - - absl::StatusOr Insert(K&& key, const Handle& value) { + absl::StatusOr Insert(UK key, Handle value) { return Insert(std::move(key), value->value()); } - absl::StatusOr Insert(K&& key, Handle&& value) { - return Insert(std::move(key), value->value()); - } - - absl::StatusOr Insert(const UK& key, const UV& value) { - return storage_.insert(std::make_pair(key, value)).second; - } - - absl::StatusOr Insert(const UK& key, UV&& value) { - return storage_.insert(std::make_pair(key, std::move(value))).second; - } - - absl::StatusOr Insert(UK&& key, const UV& value) { - return storage_.insert(std::make_pair(std::move(key), value)).second; - } - - absl::StatusOr Insert(UK&& key, UV&& value) { + absl::StatusOr Insert(UK key, UV value) { return storage_.insert(std::make_pair(std::move(key), std::move(value))) .second; } absl::StatusOr Update(const Handle& key, - const Handle& value) override { - return Update(key.As(), value.As()); - } - - absl::StatusOr Update(const Handle& key, - Handle&& value) override { - return Update(key.As(), value.As()); - } - - absl::StatusOr Update(const Handle& key, const Handle& value) { - return Update(key->value(), value->value()); + Handle value) override { + return Update(key.As(), std::move(value).As()); } - absl::StatusOr Update(const Handle& key, Handle&& value) { + absl::StatusOr Update(const Handle& key, Handle value) { return Update(key->value(), value->value()); } - absl::StatusOr Update(const Handle& key, const V& value) { - return Update(key->value(), value); - } - - absl::StatusOr Update(const Handle& key, V&& value) { + absl::StatusOr Update(const Handle& key, V value) { return Update(key->value(), std::move(value)); } - absl::StatusOr Update(const K& key, const Handle& value) { - return Update(key, value->value()); - } - - absl::StatusOr Update(const K& key, Handle&& value) { + absl::StatusOr Update(const UK& key, Handle value) { return Update(key, value->value()); } - absl::StatusOr Update(const UK& key, const UV& value) { - auto existing = storage_.find(key); - if (existing == storage_.end()) { - return false; - } - existing->second = value; - return true; - } - - absl::StatusOr Update(const UK& key, UV&& value) { + absl::StatusOr Update(const UK& key, UV value) { auto existing = storage_.find(key); if (existing == storage_.end()) { return false; @@ -1695,88 +1129,24 @@ class MapValueBuilderImpl : public MapValueBuilderInterface { return true; } - absl::StatusOr InsertOrUpdate(const Handle& key, - const Handle& value) override { - return InsertOrUpdate(key.As(), value.As()); - } - - absl::StatusOr InsertOrUpdate(const Handle& key, - Handle&& value) override { - return InsertOrUpdate(key.As(), value.As()); - } - - absl::StatusOr InsertOrUpdate(Handle&& key, - const Handle& value) override { - return InsertOrUpdate(key.As(), value.As()); - } - - absl::StatusOr InsertOrUpdate(Handle&& key, - Handle&& value) override { - return InsertOrUpdate(key.As(), value.As()); - } - - absl::StatusOr InsertOrUpdate(const Handle& key, - const Handle& value) { - return InsertOrUpdate(key->value(), value->value()); - } - - absl::StatusOr InsertOrUpdate(const Handle& key, Handle&& value) { - return InsertOrUpdate(key->value(), value->value()); - } - - absl::StatusOr InsertOrUpdate(Handle&& key, const Handle& value) { - return InsertOrUpdate(key->value(), value->value()); + absl::StatusOr InsertOrUpdate(Handle key, + Handle value) override { + return InsertOrUpdate(std::move(key).As(), std::move(value).As()); } - absl::StatusOr InsertOrUpdate(Handle&& key, Handle&& value) { + absl::StatusOr InsertOrUpdate(Handle key, Handle value) { return InsertOrUpdate(key->value(), value->value()); } - absl::StatusOr InsertOrUpdate(const Handle& key, const V& value) { - return InsertOrUpdate(key->value(), value); - } - - absl::StatusOr InsertOrUpdate(const Handle& key, V&& value) { + absl::StatusOr InsertOrUpdate(Handle key, UV value) { return InsertOrUpdate(key->value(), std::move(value)); } - absl::StatusOr InsertOrUpdate(Handle&& key, const V& value) { - return InsertOrUpdate(key->value(), value); - } - - absl::StatusOr InsertOrUpdate(Handle&& key, V&& value) { - return InsertOrUpdate(key->value(), std::move(value)); - } - - absl::StatusOr InsertOrUpdate(const K& key, const Handle& value) { - return InsertOrUpdate(key, value->value()); - } - - absl::StatusOr InsertOrUpdate(const K& key, Handle&& value) { - return InsertOrUpdate(key, value->value()); - } - - absl::StatusOr InsertOrUpdate(K&& key, const Handle& value) { + absl::StatusOr InsertOrUpdate(UK key, Handle value) { return InsertOrUpdate(std::move(key), value->value()); } - absl::StatusOr InsertOrUpdate(K&& key, Handle&& value) { - return InsertOrUpdate(std::move(key), value->value()); - } - - absl::StatusOr InsertOrUpdate(const UK& key, const UV& value) { - return storage_.insert_or_assign(key, value).second; - } - - absl::StatusOr InsertOrUpdate(const UK& key, UV&& value) { - return storage_.insert_or_assign(key, std::move(value)).second; - } - - absl::StatusOr InsertOrUpdate(UK&& key, const UV& value) { - return storage_.insert_or_assign(std::move(key), value).second; - } - - absl::StatusOr InsertOrUpdate(UK&& key, UV&& value) { + absl::StatusOr InsertOrUpdate(UK key, UV value) { return storage_.insert_or_assign(std::move(key), std::move(value)).second; } @@ -1784,7 +1154,7 @@ class MapValueBuilderImpl : public MapValueBuilderInterface { bool Has(const Handle& key) const { return Has(key->value()); } - bool Has(UK key) const { return storage_.find(key) != storage_.end(); } + bool Has(const UK& key) const { return storage_.find(key) != storage_.end(); } size_t size() const override { return storage_.size(); } From 1838dd08e7a638f5ee76bca2ed7af08694e4b09c Mon Sep 17 00:00:00 2001 From: jcking Date: Wed, 17 May 2023 19:30:37 +0000 Subject: [PATCH 286/303] Remove `noexcept` unless required to avoid bugs in older compilers PiperOrigin-RevId: 532870452 --- base/handle.h | 52 +++++++++++++++---------------- base/internal/data.h | 2 +- base/memory.h | 9 +++--- base/owner.h | 14 ++++----- base/types/list_type.h | 2 +- base/types/map_type.h | 2 +- base/types/optional_type.h | 2 +- base/values/enum_value.h | 2 +- base/values/optional_value.h | 4 +-- extensions/protobuf/enum_type.h | 2 +- extensions/protobuf/struct_type.h | 2 +- 11 files changed, 45 insertions(+), 48 deletions(-) diff --git a/base/handle.h b/base/handle.h index 24c8c8b52..fdf990b70 100644 --- a/base/handle.h +++ b/base/handle.h @@ -40,37 +40,37 @@ class Handle final : private base_internal::HandlePolicy { // Default constructs the handle, setting it to an empty state. It is // undefined behavior to call any functions that attempt to dereference or // access `T` when in an empty state. - Handle() noexcept = default; + Handle() = default; - Handle(const Handle&) noexcept = default; + Handle(const Handle&) = default; template >> - Handle(const Handle& handle) noexcept : impl_(handle.impl_) {} // NOLINT + Handle(const Handle& handle) : impl_(handle.impl_) {} // NOLINT - Handle(Handle&&) noexcept = default; + Handle(Handle&&) = default; template >> - Handle(Handle&& handle) noexcept // NOLINT + Handle(Handle&& handle) // NOLINT : impl_(std::move(handle.impl_)) {} - ~Handle() noexcept = default; + ~Handle() = default; - Handle& operator=(const Handle&) noexcept = default; + Handle& operator=(const Handle&) = default; - Handle& operator=(Handle&&) noexcept = default; + Handle& operator=(Handle&&) = default; template std::enable_if_t, Handle&> // NOLINT - operator=(const Handle& handle) noexcept { + operator=(const Handle& handle) { impl_ = handle.impl_; return *this; } template std::enable_if_t, Handle&> // NOLINT - operator=(Handle&& handle) noexcept { + operator=(Handle&& handle) { impl_ = std::move(handle.impl_); return *this; } @@ -81,11 +81,11 @@ class Handle final : private base_internal::HandlePolicy { // Handle handle; // handle.As()->SubMethod(); template - std::enable_if_t< - std::disjunction_v, std::is_base_of, - std::is_same>, - Handle&> - As() & noexcept ABSL_MUST_USE_RESULT { + std::enable_if_t< + std::disjunction_v, std::is_base_of, + std::is_same>, + Handle&> + As() & ABSL_MUST_USE_RESULT { static_assert(std::is_same_v::Impl>, "Handle and Handle must have the same " "implementation type"); @@ -107,11 +107,11 @@ class Handle final : private base_internal::HandlePolicy { // Handle handle; // handle.As()->SubMethod(); template - std::enable_if_t< - std::disjunction_v, std::is_base_of, - std::is_same>, - Handle&&> - As() && noexcept ABSL_MUST_USE_RESULT { + std::enable_if_t< + std::disjunction_v, std::is_base_of, + std::is_same>, + Handle&&> + As() && ABSL_MUST_USE_RESULT { static_assert(std::is_same_v::Impl>, "Handle and Handle must have the same " "implementation type"); @@ -137,7 +137,7 @@ class Handle final : private base_internal::HandlePolicy { std::disjunction_v, std::is_base_of, std::is_same>, const Handle&> - As() const& noexcept ABSL_MUST_USE_RESULT { + As() const& ABSL_MUST_USE_RESULT { static_assert(std::is_same_v::Impl>, "Handle and Handle must have the same " "implementation type"); @@ -163,7 +163,7 @@ class Handle final : private base_internal::HandlePolicy { std::disjunction_v, std::is_base_of, std::is_same>, const Handle&&> - As() const&& noexcept ABSL_MUST_USE_RESULT { + As() const&& ABSL_MUST_USE_RESULT { static_assert(std::is_same_v::Impl>, "Handle and Handle must have the same " "implementation type"); @@ -179,20 +179,20 @@ class Handle final : private base_internal::HandlePolicy { return std::move(*reinterpret_cast*>(this)); } - T& operator*() const noexcept ABSL_ATTRIBUTE_LIFETIME_BOUND { + T& operator*() const ABSL_ATTRIBUTE_LIFETIME_BOUND { ABSL_DCHECK(static_cast(*this)) << "cannot dereference empty handle"; return static_cast(*impl_.get()); } - T* operator->() const noexcept ABSL_ATTRIBUTE_LIFETIME_BOUND { + T* operator->() const ABSL_ATTRIBUTE_LIFETIME_BOUND { ABSL_DCHECK(static_cast(*this)) << "cannot dereference empty handle"; return static_cast(impl_.get()); } // Tests whether the handle is not empty, returning false if it is empty. - explicit operator bool() const noexcept { return static_cast(impl_); } + explicit operator bool() const { return static_cast(impl_); } - friend void swap(Handle& lhs, Handle& rhs) noexcept { + friend void swap(Handle& lhs, Handle& rhs) { std::swap(lhs.impl_, rhs.impl_); } diff --git a/base/internal/data.h b/base/internal/data.h index b18a5e434..86d1d9621 100644 --- a/base/internal/data.h +++ b/base/internal/data.h @@ -189,7 +189,7 @@ class HeapData /* : public Data */ { // our destructor called. Subclasses should override this if they want their // destructor to be skippable, by default it is not. static bool IsDestructorSkippable( - const HeapData& data ABSL_ATTRIBUTE_UNUSED) noexcept { + const HeapData& data ABSL_ATTRIBUTE_UNUSED) { return false; } diff --git a/base/memory.h b/base/memory.h index 835a0bf06..eeb5f5e59 100644 --- a/base/memory.h +++ b/base/memory.h @@ -53,15 +53,14 @@ class UniqueRef final { UniqueRef(const UniqueRef&) = delete; - UniqueRef(UniqueRef&& other) noexcept - : ref_(other.ref_), owned_(other.owned_) { + UniqueRef(UniqueRef&& other) : ref_(other.ref_), owned_(other.owned_) { other.ref_ = nullptr; other.owned_ = false; } template >> - UniqueRef(UniqueRef&& other) noexcept // NOLINT + UniqueRef(UniqueRef&& other) // NOLINT : ref_(other.ref_), owned_(other.owned_) { other.ref_ = nullptr; other.owned_ = false; @@ -79,7 +78,7 @@ class UniqueRef final { UniqueRef& operator=(const UniqueRef&) = delete; - UniqueRef& operator=(UniqueRef&& other) noexcept { + UniqueRef& operator=(UniqueRef&& other) { if (ABSL_PREDICT_TRUE(this != &other)) { if (ref_ != nullptr) { if (owned_) { @@ -98,7 +97,7 @@ class UniqueRef final { template >> - UniqueRef& operator=(UniqueRef&& other) noexcept { // NOLINT + UniqueRef& operator=(UniqueRef&& other) { // NOLINT if (ABSL_PREDICT_TRUE(this != &other)) { if (ref_ != nullptr) { if (owned_) { diff --git a/base/owner.h b/base/owner.h index 0abe125f7..e6f6f148e 100644 --- a/base/owner.h +++ b/base/owner.h @@ -42,19 +42,17 @@ class Owner { Owner() = delete; - Owner(const Owner& other) noexcept : owner_(other.owner_) { + Owner(const Owner& other) : owner_(other.owner_) { if (owner_ != nullptr) { metadata_type::Ref(*owner_); } } - Owner(Owner&& other) noexcept : owner_(other.owner_) { - other.owner_ = nullptr; - } + Owner(Owner&& other) : owner_(other.owner_) { other.owner_ = nullptr; } template >> - Owner(const Owner& other) noexcept : owner_(other.owner_) { // NOLINT + Owner(const Owner& other) : owner_(other.owner_) { // NOLINT if (owner_ != nullptr) { metadata_type::Ref(*owner_); } @@ -66,7 +64,7 @@ class Owner { other.owner_ = nullptr; } - Owner& operator=(const Owner& other) noexcept { + Owner& operator=(const Owner& other) { if (this != &other) { if (static_cast(other)) { metadata_type::Ref(*other.owner_); @@ -79,7 +77,7 @@ class Owner { return *this; } - Owner& operator=(Owner&& other) noexcept { + Owner& operator=(Owner&& other) { if (this != &other) { if (static_cast(*this)) { metadata_type::Unref(*owner_); @@ -92,7 +90,7 @@ class Owner { template >> - Owner& operator=(const Owner& other) noexcept { + Owner& operator=(const Owner& other) { if (this != &other) { if (static_cast(other)) { metadata_type::Ref(*other.owner_); diff --git a/base/types/list_type.h b/base/types/list_type.h index cea6c2870..f44dcd14e 100644 --- a/base/types/list_type.h +++ b/base/types/list_type.h @@ -109,7 +109,7 @@ class ModernListType final : public ListType, public HeapData { // Called by Arena-based memory managers to determine whether we actually need // our destructor called. - static bool IsDestructorSkippable(const ModernListType& type) noexcept { + static bool IsDestructorSkippable(const ModernListType& type) { return Metadata::IsDestructorSkippable(*type.element()); } diff --git a/base/types/map_type.h b/base/types/map_type.h index 792661ed3..9ce775a40 100644 --- a/base/types/map_type.h +++ b/base/types/map_type.h @@ -115,7 +115,7 @@ class ModernMapType final : public MapType, public HeapData { // Called by Arena-based memory managers to determine whether we actually need // our destructor called. - static bool IsDestructorSkippable(const ModernMapType& type) noexcept { + static bool IsDestructorSkippable(const ModernMapType& type) { return Metadata::IsDestructorSkippable(*type.key()) && Metadata::IsDestructorSkippable(*type.value()); } diff --git a/base/types/optional_type.h b/base/types/optional_type.h index c7a0953bd..5565160ee 100644 --- a/base/types/optional_type.h +++ b/base/types/optional_type.h @@ -57,7 +57,7 @@ class OptionalType final : public OpaqueType { // Called by Arena-based memory managers to determine whether we actually need // our destructor called. - static bool IsDestructorSkippable(const OptionalType& type) noexcept { + static bool IsDestructorSkippable(const OptionalType& type) { return base_internal::Metadata::IsDestructorSkippable(*type.type()); } diff --git a/base/values/enum_value.h b/base/values/enum_value.h index 8b10cdbba..7770e31c3 100644 --- a/base/values/enum_value.h +++ b/base/values/enum_value.h @@ -74,7 +74,7 @@ class EnumValue final : public Value, public base_internal::InlineData { base_internal::kStoredInline | (static_cast(kKind) << base_internal::kKindShift); - static uintptr_t AdditionalMetadata(const EnumType& type) noexcept { + static uintptr_t AdditionalMetadata(const EnumType& type) { static_assert( std::is_base_of_v, "This logic relies on the fact that EnumValue is stored inline"); diff --git a/base/values/optional_value.h b/base/values/optional_value.h index 54e68b7d1..24b163166 100644 --- a/base/values/optional_value.h +++ b/base/values/optional_value.h @@ -98,7 +98,7 @@ class EmptyOptionalValue final : public OptionalValue { // Called by Arena-based memory managers to determine whether we actually need // our destructor called. - static bool IsDestructorSkippable(const EmptyOptionalValue& value) noexcept { + static bool IsDestructorSkippable(const EmptyOptionalValue& value) { return base_internal::Metadata::IsDestructorSkippable(*value.type()); } @@ -117,7 +117,7 @@ class FullOptionalValue final : public OptionalValue { // Called by Arena-based memory managers to determine whether we actually need // our destructor called. - static bool IsDestructorSkippable(const FullOptionalValue& value) noexcept { + static bool IsDestructorSkippable(const FullOptionalValue& value) { return base_internal::Metadata::IsDestructorSkippable(*value.type()) && base_internal::Metadata::IsDestructorSkippable(*value.value()); } diff --git a/extensions/protobuf/enum_type.h b/extensions/protobuf/enum_type.h index b4772e1fd..cd8147ebb 100644 --- a/extensions/protobuf/enum_type.h +++ b/extensions/protobuf/enum_type.h @@ -72,7 +72,7 @@ class ProtoEnumType final : public EnumType { // Called by Arena-based memory managers to determine whether we actually need // our destructor called. static bool IsDestructorSkippable( - const ProtoEnumType& type ABSL_ATTRIBUTE_UNUSED) noexcept { + const ProtoEnumType& type ABSL_ATTRIBUTE_UNUSED) { // Our destructor is useless, we only hold pointers to protobuf-owned data. return true; } diff --git a/extensions/protobuf/struct_type.h b/extensions/protobuf/struct_type.h index bbfed1810..5edc2594a 100644 --- a/extensions/protobuf/struct_type.h +++ b/extensions/protobuf/struct_type.h @@ -85,7 +85,7 @@ class ProtoStructType final : public CEL_STRUCT_TYPE_CLASS { // Called by Arena-based memory managers to determine whether we actually need // our destructor called. static bool IsDestructorSkippable( - const ProtoStructType& type ABSL_ATTRIBUTE_UNUSED) noexcept { + const ProtoStructType& type ABSL_ATTRIBUTE_UNUSED) { // Our destructor is useless, we only hold pointers to protobuf-owned data. return true; } From 142b3b55abccae8a3756ac39e57ed6a2fb833e2f Mon Sep 17 00:00:00 2001 From: jcking Date: Wed, 17 May 2023 19:34:54 +0000 Subject: [PATCH 287/303] Merge `:type` and `:value` into `:data` PiperOrigin-RevId: 532871600 --- base/BUILD | 114 ++++++++++++++++++------------------------ base/internal/BUILD | 4 +- base/internal/type.h | 1 - base/internal/value.h | 9 ++-- base/type_test.cc | 52 +++++++++---------- base/value_test.cc | 42 ++++++++-------- 6 files changed, 101 insertions(+), 121 deletions(-) diff --git a/base/BUILD b/base/BUILD index a050068c4..70e7e46e8 100644 --- a/base/BUILD +++ b/base/BUILD @@ -137,79 +137,40 @@ cc_test( ], ) +# Build target encompassing cel::Type, cel::Value, and their related classes. cc_library( - name = "type", + name = "data", srcs = [ "type.cc", "type_factory.cc", "type_manager.cc", "type_provider.cc", - ] + glob(["types/*.cc"]), + "value.cc", + "value_factory.cc", + ] + glob( + [ + "types/*.cc", + "values/*.cc", + ], + exclude = [ + "types/*_test.cc", + "values/*_test.cc", + ], + ), hdrs = [ "type.h", "type_factory.h", "type_manager.h", "type_provider.h", "type_registry.h", - ] + glob(["types/*.h"]), - deps = [ - ":handle", - ":kind", - ":memory", - "//base/internal:data", - "//base/internal:type", - "//internal:casts", - "//internal:no_destructor", - "//internal:overloaded", - "//internal:rtti", - "//internal:status_macros", - "@com_google_absl//absl/base", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/log:absl_check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/types:optional", - "@com_google_absl//absl/types:span", - "@com_google_absl//absl/types:variant", - ], -) - -cc_test( - name = "type_test", - srcs = [ - "type_factory_test.cc", - "type_provider_test.cc", - "type_test.cc", - ], - deps = [ - ":handle", - ":memory", - ":type", - ":value", - "//base/internal:memory_manager_testing", - "//internal:testing", - "@com_google_absl//absl/hash:hash_testing", - "@com_google_absl//absl/status", - ], -) - -cc_library( - name = "value", - srcs = [ - "value.cc", - "value_factory.cc", - ] + glob( - ["values/*.cc"], - exclude = ["values/*_test.cc"], - ), - hdrs = [ "value.h", "value_factory.h", - ] + glob(["values/*.h"]), + ] + glob( + [ + "types/*.h", + "values/*.h", + ], + ), deps = [ ":attributes", ":function_result_set", @@ -217,19 +178,23 @@ cc_library( ":kind", ":memory", ":owner", - ":type", "//base/internal:data", "//base/internal:message_wrapper", + "//base/internal:type", "//base/internal:unknown_set", "//base/internal:value", "//internal:casts", "//internal:linked_hash_map", + "//internal:no_destructor", + "//internal:overloaded", "//internal:rtti", "//internal:status_macros", "//internal:strings", "//internal:time", "//internal:utf8", + "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/hash", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/log:absl_log", @@ -238,28 +203,37 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:variant", "@com_googlesource_code_re2//:re2", ], ) cc_test( - name = "value_test", + name = "data_test", srcs = [ + "type_factory_test.cc", + "type_provider_test.cc", + "type_test.cc", "value_factory_test.cc", "value_test.cc", - ] + glob(["values/*_test.cc"]), + ] + glob([ + "types/*_test.cc", + "values/*_test.cc", + ]), deps = [ + ":data", + ":handle", ":memory", - ":type", - ":value", "//base/internal:memory_manager_testing", "//internal:benchmark", "//internal:strings", "//internal:testing", "//internal:time", + "@com_google_absl//absl/hash:hash_testing", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -267,6 +241,18 @@ cc_test( ], ) +alias( + name = "type", + actual = ":data", + deprecation = "Use :data instead.", +) + +alias( + name = "value", + actual = ":data", + deprecation = "Use :data instead.", +) + cc_library( name = "ast_internal", srcs = ["ast_internal.cc"], diff --git a/base/internal/BUILD b/base/internal/BUILD index f8bcbccef..e2d356430 100644 --- a/base/internal/BUILD +++ b/base/internal/BUILD @@ -85,7 +85,6 @@ cc_library( ], deps = [ ":data", - "//base:handle", "//base:kind", "//internal:rtti", ], @@ -110,9 +109,8 @@ cc_library( ], deps = [ ":data", - ":unknown_set", + ":type", "//base:handle", - "//base:type", "//internal:rtti", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", diff --git a/base/internal/type.h b/base/internal/type.h index c9ddbe09c..567c28d2a 100644 --- a/base/internal/type.h +++ b/base/internal/type.h @@ -19,7 +19,6 @@ #include -#include "base/handle.h" #include "base/internal/data.h" #include "base/kind.h" #include "internal/rtti.h" diff --git a/base/internal/value.h b/base/internal/value.h index 7a8621eed..78788f776 100644 --- a/base/internal/value.h +++ b/base/internal/value.h @@ -20,9 +20,7 @@ #include #include #include -#include #include -#include #include "absl/status/status.h" #include "absl/strings/cord.h" @@ -31,8 +29,7 @@ #include "absl/types/variant.h" #include "base/handle.h" #include "base/internal/data.h" -#include "base/internal/unknown_set.h" -#include "base/types/enum_type.h" +#include "base/internal/type.h" #include "internal/rtti.h" namespace cel { @@ -89,9 +86,9 @@ struct InlineValue final { absl::string_view string_value; uintptr_t owner; } string_value; - Handle type_value; + AnyType type_value; struct { - Handle type; + AnyType type; int64_t number; } enum_value; }; diff --git a/base/type_test.cc b/base/type_test.cc index fdbdc2453..bacc1557f 100644 --- a/base/type_test.cc +++ b/base/type_test.cc @@ -611,78 +611,78 @@ INSTANTIATE_TEST_SUITE_P(OptionalTypeTest, OptionalTypeTest, base_internal::MemoryManagerTestModeAll(), base_internal::MemoryManagerTestModeName); -class DebugStringTest : public TypeTest {}; +class TypeDebugStringTest : public TypeTest {}; -TEST_P(DebugStringTest, NullType) { +TEST_P(TypeDebugStringTest, NullType) { TypeFactory type_factory(memory_manager()); EXPECT_EQ(type_factory.GetNullType()->DebugString(), "null_type"); } -TEST_P(DebugStringTest, ErrorType) { +TEST_P(TypeDebugStringTest, ErrorType) { TypeFactory type_factory(memory_manager()); EXPECT_EQ(type_factory.GetErrorType()->DebugString(), "*error*"); } -TEST_P(DebugStringTest, DynType) { +TEST_P(TypeDebugStringTest, DynType) { TypeFactory type_factory(memory_manager()); EXPECT_EQ(type_factory.GetDynType()->DebugString(), "dyn"); } -TEST_P(DebugStringTest, AnyType) { +TEST_P(TypeDebugStringTest, AnyType) { TypeFactory type_factory(memory_manager()); EXPECT_EQ(type_factory.GetAnyType()->DebugString(), "google.protobuf.Any"); } -TEST_P(DebugStringTest, BoolType) { +TEST_P(TypeDebugStringTest, BoolType) { TypeFactory type_factory(memory_manager()); EXPECT_EQ(type_factory.GetBoolType()->DebugString(), "bool"); } -TEST_P(DebugStringTest, IntType) { +TEST_P(TypeDebugStringTest, IntType) { TypeFactory type_factory(memory_manager()); EXPECT_EQ(type_factory.GetIntType()->DebugString(), "int"); } -TEST_P(DebugStringTest, UintType) { +TEST_P(TypeDebugStringTest, UintType) { TypeFactory type_factory(memory_manager()); EXPECT_EQ(type_factory.GetUintType()->DebugString(), "uint"); } -TEST_P(DebugStringTest, DoubleType) { +TEST_P(TypeDebugStringTest, DoubleType) { TypeFactory type_factory(memory_manager()); EXPECT_EQ(type_factory.GetDoubleType()->DebugString(), "double"); } -TEST_P(DebugStringTest, StringType) { +TEST_P(TypeDebugStringTest, StringType) { TypeFactory type_factory(memory_manager()); EXPECT_EQ(type_factory.GetStringType()->DebugString(), "string"); } -TEST_P(DebugStringTest, BytesType) { +TEST_P(TypeDebugStringTest, BytesType) { TypeFactory type_factory(memory_manager()); EXPECT_EQ(type_factory.GetBytesType()->DebugString(), "bytes"); } -TEST_P(DebugStringTest, DurationType) { +TEST_P(TypeDebugStringTest, DurationType) { TypeFactory type_factory(memory_manager()); EXPECT_EQ(type_factory.GetDurationType()->DebugString(), "google.protobuf.Duration"); } -TEST_P(DebugStringTest, TimestampType) { +TEST_P(TypeDebugStringTest, TimestampType) { TypeFactory type_factory(memory_manager()); EXPECT_EQ(type_factory.GetTimestampType()->DebugString(), "google.protobuf.Timestamp"); } -TEST_P(DebugStringTest, EnumType) { +TEST_P(TypeDebugStringTest, EnumType) { TypeFactory type_factory(memory_manager()); ASSERT_OK_AND_ASSIGN(auto enum_type, type_factory.CreateEnumType()); EXPECT_EQ(enum_type->DebugString(), "test_enum.TestEnum"); } -TEST_P(DebugStringTest, StructType) { +TEST_P(TypeDebugStringTest, StructType) { TypeFactory type_factory(memory_manager()); TypeManager type_manager(type_factory, TypeProvider::Builtin()); ASSERT_OK_AND_ASSIGN( @@ -691,14 +691,14 @@ TEST_P(DebugStringTest, StructType) { EXPECT_EQ(struct_type->DebugString(), "test_struct.TestStruct"); } -TEST_P(DebugStringTest, ListType) { +TEST_P(TypeDebugStringTest, ListType) { TypeFactory type_factory(memory_manager()); ASSERT_OK_AND_ASSIGN(auto list_type, type_factory.CreateListType(type_factory.GetBoolType())); EXPECT_EQ(list_type->DebugString(), "list(bool)"); } -TEST_P(DebugStringTest, MapType) { +TEST_P(TypeDebugStringTest, MapType) { TypeFactory type_factory(memory_manager()); ASSERT_OK_AND_ASSIGN(auto map_type, type_factory.CreateMapType(type_factory.GetStringType(), @@ -706,55 +706,55 @@ TEST_P(DebugStringTest, MapType) { EXPECT_EQ(map_type->DebugString(), "map(string, bool)"); } -TEST_P(DebugStringTest, TypeType) { +TEST_P(TypeDebugStringTest, TypeType) { TypeFactory type_factory(memory_manager()); EXPECT_EQ(type_factory.GetTypeType()->DebugString(), "type"); } -TEST_P(DebugStringTest, OptionalType) { +TEST_P(TypeDebugStringTest, OptionalType) { TypeFactory type_factory(memory_manager()); ASSERT_OK_AND_ASSIGN(auto optional_type, type_factory.CreateOptionalType( type_factory.GetStringType())); EXPECT_EQ(optional_type->DebugString(), "optional"); } -TEST_P(DebugStringTest, BoolWrapperType) { +TEST_P(TypeDebugStringTest, BoolWrapperType) { TypeFactory type_factory(memory_manager()); EXPECT_EQ(type_factory.GetBoolWrapperType()->DebugString(), "google.protobuf.BoolValue"); } -TEST_P(DebugStringTest, BytesWrapperType) { +TEST_P(TypeDebugStringTest, BytesWrapperType) { TypeFactory type_factory(memory_manager()); EXPECT_EQ(type_factory.GetBytesWrapperType()->DebugString(), "google.protobuf.BytesValue"); } -TEST_P(DebugStringTest, DoubleWrapperType) { +TEST_P(TypeDebugStringTest, DoubleWrapperType) { TypeFactory type_factory(memory_manager()); EXPECT_EQ(type_factory.GetDoubleWrapperType()->DebugString(), "google.protobuf.DoubleValue"); } -TEST_P(DebugStringTest, IntWrapperType) { +TEST_P(TypeDebugStringTest, IntWrapperType) { TypeFactory type_factory(memory_manager()); EXPECT_EQ(type_factory.GetIntWrapperType()->DebugString(), "google.protobuf.Int64Value"); } -TEST_P(DebugStringTest, StringWrapperType) { +TEST_P(TypeDebugStringTest, StringWrapperType) { TypeFactory type_factory(memory_manager()); EXPECT_EQ(type_factory.GetStringWrapperType()->DebugString(), "google.protobuf.StringValue"); } -TEST_P(DebugStringTest, UintWrapperType) { +TEST_P(TypeDebugStringTest, UintWrapperType) { TypeFactory type_factory(memory_manager()); EXPECT_EQ(type_factory.GetUintWrapperType()->DebugString(), "google.protobuf.UInt64Value"); } -INSTANTIATE_TEST_SUITE_P(DebugStringTest, DebugStringTest, +INSTANTIATE_TEST_SUITE_P(TypeDebugStringTest, TypeDebugStringTest, base_internal::MemoryManagerTestModeAll(), base_internal::MemoryManagerTestModeName); diff --git a/base/value_test.cc b/base/value_test.cc index 24ae2f316..5a8874431 100644 --- a/base/value_test.cc +++ b/base/value_test.cc @@ -629,16 +629,16 @@ TEST_P(ValueTest, Swap) { EXPECT_EQ(rhs, value_factory.CreateIntValue(0)); } -using DebugStringTest = ValueTest; +using ValueDebugStringTest = ValueTest; -TEST_P(DebugStringTest, NullValue) { +TEST_P(ValueDebugStringTest, NullValue) { TypeFactory type_factory(memory_manager()); TypeManager type_manager(type_factory, TypeProvider::Builtin()); ValueFactory value_factory(type_manager); EXPECT_EQ(value_factory.GetNullValue()->DebugString(), "null"); } -TEST_P(DebugStringTest, BoolValue) { +TEST_P(ValueDebugStringTest, BoolValue) { TypeFactory type_factory(memory_manager()); TypeManager type_manager(type_factory, TypeProvider::Builtin()); ValueFactory value_factory(type_manager); @@ -646,7 +646,7 @@ TEST_P(DebugStringTest, BoolValue) { EXPECT_EQ(value_factory.CreateBoolValue(true)->DebugString(), "true"); } -TEST_P(DebugStringTest, IntValue) { +TEST_P(ValueDebugStringTest, IntValue) { TypeFactory type_factory(memory_manager()); TypeManager type_manager(type_factory, TypeProvider::Builtin()); ValueFactory value_factory(type_manager); @@ -661,7 +661,7 @@ TEST_P(DebugStringTest, IntValue) { "9223372036854775807"); } -TEST_P(DebugStringTest, UintValue) { +TEST_P(ValueDebugStringTest, UintValue) { TypeFactory type_factory(memory_manager()); TypeManager type_manager(type_factory, TypeProvider::Builtin()); ValueFactory value_factory(type_manager); @@ -672,7 +672,7 @@ TEST_P(DebugStringTest, UintValue) { "18446744073709551615u"); } -TEST_P(DebugStringTest, DoubleValue) { +TEST_P(ValueDebugStringTest, DoubleValue) { TypeFactory type_factory(memory_manager()); TypeManager type_manager(type_factory, TypeProvider::Builtin()); ValueFactory value_factory(type_manager); @@ -707,7 +707,7 @@ TEST_P(DebugStringTest, DoubleValue) { "-infinity"); } -TEST_P(DebugStringTest, DurationValue) { +TEST_P(ValueDebugStringTest, DurationValue) { TypeFactory type_factory(memory_manager()); TypeManager type_manager(type_factory, TypeProvider::Builtin()); ValueFactory value_factory(type_manager); @@ -715,7 +715,7 @@ TEST_P(DebugStringTest, DurationValue) { internal::FormatDuration(absl::ZeroDuration()).value()); } -TEST_P(DebugStringTest, TimestampValue) { +TEST_P(ValueDebugStringTest, TimestampValue) { TypeFactory type_factory(memory_manager()); TypeManager type_manager(type_factory, TypeProvider::Builtin()); ValueFactory value_factory(type_manager); @@ -723,7 +723,7 @@ TEST_P(DebugStringTest, TimestampValue) { internal::FormatTimestamp(absl::UnixEpoch()).value()); } -INSTANTIATE_TEST_SUITE_P(DebugStringTest, DebugStringTest, +INSTANTIATE_TEST_SUITE_P(ValueDebugStringTest, ValueDebugStringTest, base_internal::MemoryManagerTestModeAll(), base_internal::MemoryManagerTestModeTupleName); @@ -2034,9 +2034,9 @@ TEST_P(ValueTest, Enum) { EXPECT_NE(two_value, one_value); } -using EnumTypeTest = ValueTest; +using EnumValueTest = ValueTest; -TEST_P(EnumTypeTest, NewInstance) { +TEST_P(EnumValueTest, NewInstance) { TypeFactory type_factory(memory_manager()); TypeManager type_manager(type_factory, TypeProvider::Builtin()); ValueFactory value_factory(type_manager); @@ -2059,7 +2059,16 @@ TEST_P(EnumTypeTest, NewInstance) { StatusIs(absl::StatusCode::kNotFound)); } -INSTANTIATE_TEST_SUITE_P(EnumTypeTest, EnumTypeTest, +TEST_P(EnumValueTest, UnknownConstantDebugString) { + TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); + ASSERT_OK_AND_ASSIGN(auto enum_type, + type_factory.CreateEnumType()); + EXPECT_EQ(EnumValue::DebugString(*enum_type, 3), "test_enum.TestEnum(3)"); +} + +INSTANTIATE_TEST_SUITE_P(EnumValueTest, EnumValueTest, base_internal::MemoryManagerTestModeAll(), base_internal::MemoryManagerTestModeTupleName); @@ -2457,15 +2466,6 @@ TEST(TypeValue, SkippableDestructor) { EXPECT_TRUE(base_internal::Metadata::IsDestructorSkippable(*type_value)); } -TEST(EnumValueTest, UnknownConstantDebugString) { - TypeFactory type_factory(MemoryManager::Global()); - TypeManager type_manager(type_factory, TypeProvider::Builtin()); - ValueFactory value_factory(type_manager); - ASSERT_OK_AND_ASSIGN(auto enum_type, - type_factory.CreateEnumType()); - EXPECT_EQ(EnumValue::DebugString(*enum_type, 3), "test_enum.TestEnum(3)"); -} - Handle DefaultNullValue(ValueFactory& value_factory) { return value_factory.GetNullValue(); } From 831f5fea088c9d4a2d4d34667ab565bf66f0c8b1 Mon Sep 17 00:00:00 2001 From: jcking Date: Thu, 18 May 2023 13:15:47 +0000 Subject: [PATCH 288/303] Workaround GCC friendship bugs PiperOrigin-RevId: 533103982 --- base/internal/memory_manager.h | 3 +-- base/memory.h | 14 +++++++++++++- base/types/list_type.h | 5 +++-- base/types/map_type.h | 7 ++++--- base/types/optional_type.h | 5 +++-- base/values/optional_value.h | 11 ++++++----- extensions/protobuf/BUILD | 3 ++- extensions/protobuf/enum_type.h | 4 ++-- extensions/protobuf/struct_type.h | 4 ++-- 9 files changed, 36 insertions(+), 20 deletions(-) diff --git a/base/internal/memory_manager.h b/base/internal/memory_manager.h index c9bf1e118..07a1dd72b 100644 --- a/base/internal/memory_manager.h +++ b/base/internal/memory_manager.h @@ -32,8 +32,7 @@ struct HasIsDestructorSkippable : std::false_type {}; template struct HasIsDestructorSkippable< - T, - std::void_t()))>> + T, std::void_t().IsDestructorSkippable())>> : std::true_type {}; } // namespace cel::base_internal diff --git a/base/memory.h b/base/memory.h index eeb5f5e59..7a3745c28 100644 --- a/base/memory.h +++ b/base/memory.h @@ -185,7 +185,7 @@ class MemoryManager { T(std::forward(args)...); if constexpr (!std::is_trivially_destructible_v) { if constexpr (base_internal::HasIsDestructorSkippable::value) { - if (!T::IsDestructorSkippable(*pointer)) { + if (!pointer->IsDestructorSkippable()) { OwnDestructor(pointer, &base_internal::MemoryManagerDestructor::Destruct); } @@ -357,6 +357,18 @@ class Allocator { bool allocation_only_; }; +// GCC before 12 has buggy friendship. Instead of calculating friendship at the +// point of evaluation it does so at the point where it is written. This macro +// ensures compatibility by friending both so IsDestructorSkippable works +// correctly. +#define CEL_INTERNAL_IS_DESTRUCTOR_SKIPPABLE() \ + private: \ + friend class ::cel::MemoryManager; \ + template \ + friend struct ::cel::base_internal::HasIsDestructorSkippable; \ + \ + bool IsDestructorSkippable() const + namespace base_internal { template diff --git a/base/types/list_type.h b/base/types/list_type.h index f44dcd14e..530898074 100644 --- a/base/types/list_type.h +++ b/base/types/list_type.h @@ -24,6 +24,7 @@ #include "absl/types/span.h" #include "base/internal/data.h" #include "base/kind.h" +#include "base/memory.h" #include "base/type.h" namespace cel { @@ -109,8 +110,8 @@ class ModernListType final : public ListType, public HeapData { // Called by Arena-based memory managers to determine whether we actually need // our destructor called. - static bool IsDestructorSkippable(const ModernListType& type) { - return Metadata::IsDestructorSkippable(*type.element()); + CEL_INTERNAL_IS_DESTRUCTOR_SKIPPABLE() { + return Metadata::IsDestructorSkippable(*element()); } explicit ModernListType(Handle element); diff --git a/base/types/map_type.h b/base/types/map_type.h index 9ce775a40..083f6f82d 100644 --- a/base/types/map_type.h +++ b/base/types/map_type.h @@ -24,6 +24,7 @@ #include "absl/types/span.h" #include "base/internal/data.h" #include "base/kind.h" +#include "base/memory.h" #include "base/type.h" namespace cel { @@ -115,9 +116,9 @@ class ModernMapType final : public MapType, public HeapData { // Called by Arena-based memory managers to determine whether we actually need // our destructor called. - static bool IsDestructorSkippable(const ModernMapType& type) { - return Metadata::IsDestructorSkippable(*type.key()) && - Metadata::IsDestructorSkippable(*type.value()); + CEL_INTERNAL_IS_DESTRUCTOR_SKIPPABLE() { + return Metadata::IsDestructorSkippable(*key()) && + Metadata::IsDestructorSkippable(*value()); } explicit ModernMapType(Handle key, Handle value); diff --git a/base/types/optional_type.h b/base/types/optional_type.h index 5565160ee..b949c1957 100644 --- a/base/types/optional_type.h +++ b/base/types/optional_type.h @@ -21,6 +21,7 @@ #include "absl/log/absl_check.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "base/memory.h" #include "base/type.h" #include "base/types/opaque_type.h" #include "internal/rtti.h" @@ -57,8 +58,8 @@ class OptionalType final : public OpaqueType { // Called by Arena-based memory managers to determine whether we actually need // our destructor called. - static bool IsDestructorSkippable(const OptionalType& type) { - return base_internal::Metadata::IsDestructorSkippable(*type.type()); + CEL_INTERNAL_IS_DESTRUCTOR_SKIPPABLE() { + return base_internal::Metadata::IsDestructorSkippable(*type()); } explicit OptionalType(Handle type) : type_(std::move(type)) {} diff --git a/base/values/optional_value.h b/base/values/optional_value.h index 24b163166..1e038a6e5 100644 --- a/base/values/optional_value.h +++ b/base/values/optional_value.h @@ -21,6 +21,7 @@ #include "absl/log/absl_check.h" #include "absl/status/statusor.h" #include "base/handle.h" +#include "base/memory.h" #include "base/type.h" #include "base/types/optional_type.h" #include "base/value.h" @@ -98,8 +99,8 @@ class EmptyOptionalValue final : public OptionalValue { // Called by Arena-based memory managers to determine whether we actually need // our destructor called. - static bool IsDestructorSkippable(const EmptyOptionalValue& value) { - return base_internal::Metadata::IsDestructorSkippable(*value.type()); + CEL_INTERNAL_IS_DESTRUCTOR_SKIPPABLE() { + return base_internal::Metadata::IsDestructorSkippable(*type()); } explicit EmptyOptionalValue(Handle type) @@ -117,9 +118,9 @@ class FullOptionalValue final : public OptionalValue { // Called by Arena-based memory managers to determine whether we actually need // our destructor called. - static bool IsDestructorSkippable(const FullOptionalValue& value) { - return base_internal::Metadata::IsDestructorSkippable(*value.type()) && - base_internal::Metadata::IsDestructorSkippable(*value.value()); + CEL_INTERNAL_IS_DESTRUCTOR_SKIPPABLE() { + return base_internal::Metadata::IsDestructorSkippable(*type()) && + base_internal::Metadata::IsDestructorSkippable(*value()); } FullOptionalValue(Handle type, Handle value) diff --git a/extensions/protobuf/BUILD b/extensions/protobuf/BUILD index 7fe23ec5b..cb68a61e5 100644 --- a/extensions/protobuf/BUILD +++ b/extensions/protobuf/BUILD @@ -95,8 +95,9 @@ cc_library( "type_provider.h", ], deps = [ + "//base:data", "//base:handle", - "//base:type", + "//base:memory", "//internal:status_macros", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log:die_if_null", diff --git a/extensions/protobuf/enum_type.h b/extensions/protobuf/enum_type.h index cd8147ebb..cd9da1ee2 100644 --- a/extensions/protobuf/enum_type.h +++ b/extensions/protobuf/enum_type.h @@ -17,6 +17,7 @@ #include "absl/base/attributes.h" #include "absl/log/die_if_null.h" +#include "base/memory.h" #include "base/type.h" #include "base/type_manager.h" #include "base/types/enum_type.h" @@ -71,8 +72,7 @@ class ProtoEnumType final : public EnumType { // Called by Arena-based memory managers to determine whether we actually need // our destructor called. - static bool IsDestructorSkippable( - const ProtoEnumType& type ABSL_ATTRIBUTE_UNUSED) { + CEL_INTERNAL_IS_DESTRUCTOR_SKIPPABLE() { // Our destructor is useless, we only hold pointers to protobuf-owned data. return true; } diff --git a/extensions/protobuf/struct_type.h b/extensions/protobuf/struct_type.h index 5edc2594a..a6da09cd0 100644 --- a/extensions/protobuf/struct_type.h +++ b/extensions/protobuf/struct_type.h @@ -22,6 +22,7 @@ #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "base/handle.h" +#include "base/memory.h" #include "base/type.h" #include "base/type_manager.h" #include "base/types/struct_type.h" @@ -84,8 +85,7 @@ class ProtoStructType final : public CEL_STRUCT_TYPE_CLASS { // Called by Arena-based memory managers to determine whether we actually need // our destructor called. - static bool IsDestructorSkippable( - const ProtoStructType& type ABSL_ATTRIBUTE_UNUSED) { + CEL_INTERNAL_IS_DESTRUCTOR_SKIPPABLE() { // Our destructor is useless, we only hold pointers to protobuf-owned data. return true; } From b35780ee26d22366a71985bef8cea8324ed08092 Mon Sep 17 00:00:00 2001 From: kuat Date: Thu, 18 May 2023 17:02:45 +0000 Subject: [PATCH 289/303] Use ABSL_LOG macros over LOG macros LOG causes macro pollution and conflicts with ProxyWasm SDK (see https://github.com/proxy-wasm/proxy-wasm-cpp-sdk/issues/154). PiperOrigin-RevId: 533164114 --- eval/compiler/flat_expr_builder_test.cc | 2 +- eval/internal/BUILD | 1 - eval/internal/interop.cc | 1 - eval/public/BUILD | 10 +++++----- eval/public/ast_rewrite.cc | 4 ++-- eval/public/ast_rewrite_native.cc | 4 ++-- eval/public/ast_traverse.cc | 4 ++-- eval/public/ast_traverse_native.cc | 4 ++-- eval/public/builtin_func_test.cc | 2 +- eval/public/cel_value.h | 11 ++++++----- eval/public/structs/cel_proto_lite_wrap_util.cc | 6 +++--- eval/public/structs/cel_proto_wrap_util.cc | 6 +++--- eval/public/structs/field_access_impl_test.cc | 12 ++++++------ eval/public/structs/proto_message_type_adapter.cc | 8 ++++---- .../structs/protobuf_descriptor_type_provider.cc | 2 +- extensions/protobuf/type_provider.cc | 2 +- 16 files changed, 39 insertions(+), 40 deletions(-) diff --git a/eval/compiler/flat_expr_builder_test.cc b/eval/compiler/flat_expr_builder_test.cc index e49f92fe9..3a52b73ac 100644 --- a/eval/compiler/flat_expr_builder_test.cc +++ b/eval/compiler/flat_expr_builder_test.cc @@ -1973,7 +1973,7 @@ TEST(FlatExprBuilderTest, CustomDescriptorPoolForSelect) { std::pair CreateTestMessage( const google::protobuf::DescriptorPool& descriptor_pool, google::protobuf::MessageFactory& message_factory, absl::string_view name) { - const google::protobuf::Descriptor* desc = descriptor_pool.FindMessageTypeByName(std::string(name)); + const google::protobuf::Descriptor* desc = descriptor_pool.FindMessageTypeByName(name); const google::protobuf::Message* message_prototype = message_factory.GetPrototype(desc); google::protobuf::Message* message = message_prototype->New(); const google::protobuf::Reflection* refl = message->GetReflection(); diff --git a/eval/internal/BUILD b/eval/internal/BUILD index 0abc3ca8e..c91efc0ab 100644 --- a/eval/internal/BUILD +++ b/eval/internal/BUILD @@ -35,7 +35,6 @@ cc_library( "//internal:status_macros", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", diff --git a/eval/internal/interop.cc b/eval/internal/interop.cc index c2ae7eb51..c3fb8d082 100644 --- a/eval/internal/interop.cc +++ b/eval/internal/interop.cc @@ -21,7 +21,6 @@ #include "google/protobuf/arena.h" #include "absl/base/attributes.h" #include "absl/container/flat_hash_map.h" -#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" diff --git a/eval/public/BUILD b/eval/public/BUILD index 85c051fac..65104bc69 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -83,7 +83,7 @@ cc_library( "//internal:status_macros", "//internal:utf8", "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/log", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -633,7 +633,7 @@ cc_library( deps = [ ":ast_visitor", ":source_position", - "@com_google_absl//absl/log", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/types:variant", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", ], @@ -651,7 +651,7 @@ cc_library( ":ast_visitor_native", ":source_position_native", "//base:ast_internal", - "@com_google_absl//absl/log", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/types:variant", ], ) @@ -835,7 +835,7 @@ cc_library( deps = [ ":ast_visitor", ":source_position", - "@com_google_absl//absl/log", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:variant", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", @@ -870,7 +870,7 @@ cc_library( deps = [ ":ast_visitor_native", ":source_position_native", - "@com_google_absl//absl/log", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:variant", ], diff --git a/eval/public/ast_rewrite.cc b/eval/public/ast_rewrite.cc index c509a3a80..1d4f09393 100644 --- a/eval/public/ast_rewrite.cc +++ b/eval/public/ast_rewrite.cc @@ -18,7 +18,7 @@ #include #include "google/api/expr/v1alpha1/syntax.pb.h" -#include "absl/log/log.h" +#include "absl/log/absl_log.h" #include "absl/types/variant.h" #include "eval/public/ast_visitor.h" #include "eval/public/source_position.h" @@ -196,7 +196,7 @@ struct PostVisitor { case Expr::EXPR_KIND_NOT_SET: break; default: - LOG(ERROR) << "Unsupported Expr kind: " << expr->expr_kind_case(); + ABSL_LOG(ERROR) << "Unsupported Expr kind: " << expr->expr_kind_case(); } visitor->PostVisitExpr(expr, &position); diff --git a/eval/public/ast_rewrite_native.cc b/eval/public/ast_rewrite_native.cc index 3c006d5ab..89248cd3d 100644 --- a/eval/public/ast_rewrite_native.cc +++ b/eval/public/ast_rewrite_native.cc @@ -17,7 +17,7 @@ #include #include -#include "absl/log/log.h" +#include "absl/log/absl_log.h" #include "absl/types/variant.h" #include "eval/public/ast_visitor_native.h" #include "eval/public/source_position_native.h" @@ -199,7 +199,7 @@ struct PostVisitor { visitor->PostVisitComprehension(&comprehension, expr, position); } void operator()(absl::monostate) { - LOG(ERROR) << "Unsupported Expr kind"; + ABSL_LOG(ERROR) << "Unsupported Expr kind"; } } handler{visitor, record.expr, &position}; absl::visit(handler, record.expr->expr_kind()); diff --git a/eval/public/ast_traverse.cc b/eval/public/ast_traverse.cc index 96730841e..ce1a66202 100644 --- a/eval/public/ast_traverse.cc +++ b/eval/public/ast_traverse.cc @@ -17,7 +17,7 @@ #include #include "google/api/expr/v1alpha1/syntax.pb.h" -#include "absl/log/log.h" +#include "absl/log/absl_log.h" #include "absl/types/variant.h" #include "eval/public/ast_visitor.h" #include "eval/public/source_position.h" @@ -197,7 +197,7 @@ struct PostVisitor { &position); break; default: - LOG(ERROR) << "Unsupported Expr kind: " << expr->expr_kind_case(); + ABSL_LOG(ERROR) << "Unsupported Expr kind: " << expr->expr_kind_case(); } visitor->PostVisitExpr(expr, &position); diff --git a/eval/public/ast_traverse_native.cc b/eval/public/ast_traverse_native.cc index 97df453b7..c156a3ee8 100644 --- a/eval/public/ast_traverse_native.cc +++ b/eval/public/ast_traverse_native.cc @@ -16,7 +16,7 @@ #include -#include "absl/log/log.h" +#include "absl/log/absl_log.h" #include "absl/types/variant.h" #include "base/ast_internal.h" #include "eval/public/ast_visitor_native.h" @@ -174,7 +174,7 @@ struct PostVisitor { &position); } void operator()(absl::monostate) { - LOG(ERROR) << "Unsupported Expr kind"; + ABSL_LOG(ERROR) << "Unsupported Expr kind"; } } handler{visitor, record.expr, SourcePosition(expr->id(), record.source_info)}; diff --git a/eval/public/builtin_func_test.cc b/eval/public/builtin_func_test.cc index 96f6f5044..40c4b702e 100644 --- a/eval/public/builtin_func_test.cc +++ b/eval/public/builtin_func_test.cc @@ -69,7 +69,7 @@ class BuiltinsTest : public ::testing::Test { Expr expr; SourceInfo source_info; auto call = expr.mutable_call_expr(); - call->set_function(operation.data(), operation.size()); + call->set_function(operation); if (target.has_value()) { std::string param_name = "target"; diff --git a/eval/public/cel_value.h b/eval/public/cel_value.h index 031585652..6bb60beb6 100644 --- a/eval/public/cel_value.h +++ b/eval/public/cel_value.h @@ -25,7 +25,7 @@ #include "absl/base/attributes.h" #include "absl/base/macros.h" #include "absl/base/optimization.h" -#include "absl/log/log.h" +#include "absl/log/absl_log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -489,7 +489,8 @@ class CelValue { // Crashes with a null pointer error. static void CrashNullPointer(Type type) ABSL_ATTRIBUTE_COLD { - LOG(FATAL) << "Null pointer supplied for " << TypeName(type); // Crash ok + ABSL_LOG(FATAL) << "Null pointer supplied for " + << TypeName(type); // Crash ok } // Null pointer checker for pointer-based types. @@ -502,9 +503,9 @@ class CelValue { // Crashes with a type mismatch error. static void CrashTypeMismatch(Type requested_type, Type actual_type) ABSL_ATTRIBUTE_COLD { - LOG(FATAL) << "Type mismatch" // Crash ok - << ": expected " << TypeName(requested_type) // Crash ok - << ", encountered " << TypeName(actual_type); // Crash ok + ABSL_LOG(FATAL) << "Type mismatch" // Crash ok + << ": expected " << TypeName(requested_type) // Crash ok + << ", encountered " << TypeName(actual_type); // Crash ok } // Gets value of type specified diff --git a/eval/public/structs/cel_proto_lite_wrap_util.cc b/eval/public/structs/cel_proto_lite_wrap_util.cc index f99306bcd..4cb21e576 100644 --- a/eval/public/structs/cel_proto_lite_wrap_util.cc +++ b/eval/public/structs/cel_proto_lite_wrap_util.cc @@ -698,7 +698,7 @@ absl::StatusOr CreateMessageFromValue( if (wrapper == nullptr) { wrapper = google::protobuf::Arena::CreateMessage(arena); } - wrapper->set_value(view_val.value().data(), view_val.value().size()); + wrapper->set_value(view_val.value()); return wrapper; } @@ -782,7 +782,7 @@ absl::StatusOr CreateMessageFromValue( if (wrapper == nullptr) { wrapper = google::protobuf::Arena::CreateMessage(arena); } - wrapper->set_value(view_val.value().data(), view_val.value().size()); + wrapper->set_value(view_val.value()); return wrapper; } @@ -950,7 +950,7 @@ absl::StatusOr CreateMessageFromValue( case CelValue::Type::kString: { CelValue::StringHolder val; if (cel_value.GetValue(&val)) { - wrapper->set_string_value(val.value().data(), val.value().size()); + wrapper->set_string_value(val.value()); } } break; case CelValue::Type::kTimestamp: { diff --git a/eval/public/structs/cel_proto_wrap_util.cc b/eval/public/structs/cel_proto_wrap_util.cc index 1f077a9ab..9df9c0099 100644 --- a/eval/public/structs/cel_proto_wrap_util.cc +++ b/eval/public/structs/cel_proto_wrap_util.cc @@ -468,7 +468,7 @@ google::protobuf::Message* MessageFromValue( if (!value.GetValue(&view_val)) { return nullptr; } - wrapper->set_value(view_val.value().data(), view_val.value().size()); + wrapper->set_value(view_val.value()); return wrapper; } @@ -536,7 +536,7 @@ google::protobuf::Message* MessageFromValue( if (!value.GetValue(&view_val)) { return nullptr; } - wrapper->set_value(view_val.value().data(), view_val.value().size()); + wrapper->set_value(view_val.value()); return wrapper; } @@ -682,7 +682,7 @@ google::protobuf::Message* MessageFromValue(const CelValue& value, Value* json, case CelValue::Type::kString: { CelValue::StringHolder val; if (value.GetValue(&val)) { - json->set_string_value(val.value().data(), val.value().size()); + json->set_string_value(val.value()); return json; } } break; diff --git a/eval/public/structs/field_access_impl_test.cc b/eval/public/structs/field_access_impl_test.cc index d5f259127..86b357803 100644 --- a/eval/public/structs/field_access_impl_test.cc +++ b/eval/public/structs/field_access_impl_test.cc @@ -184,14 +184,14 @@ class SingleFieldTest : public testing::TestWithParam { TEST_P(SingleFieldTest, Getter) { TestAllTypes test_message; ASSERT_TRUE( - google::protobuf::TextFormat::ParseFromString(std::string(message_textproto()), &test_message)); + google::protobuf::TextFormat::ParseFromString(message_textproto(), &test_message)); google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN( CelValue accessed_value, CreateValueFromSingleField( &test_message, - test_message.GetDescriptor()->FindFieldByName(std::string(field_name())), + test_message.GetDescriptor()->FindFieldByName(field_name()), ProtoWrapperTypeOptions::kUnsetProtoDefault, &CelProtoWrapper::InternalWrapMessage, &arena)); @@ -204,7 +204,7 @@ TEST_P(SingleFieldTest, Setter) { google::protobuf::Arena arena; ASSERT_OK(SetValueToSingleField( - to_set, test_message.GetDescriptor()->FindFieldByName(std::string(field_name())), + to_set, test_message.GetDescriptor()->FindFieldByName(field_name()), &test_message, &arena)); EXPECT_THAT(test_message, EqualsProto(message_textproto())); @@ -361,14 +361,14 @@ class RepeatedFieldTest : public testing::TestWithParam { TEST_P(RepeatedFieldTest, GetFirstElem) { TestAllTypes test_message; ASSERT_TRUE( - google::protobuf::TextFormat::ParseFromString(std::string(message_textproto()), &test_message)); + google::protobuf::TextFormat::ParseFromString(message_textproto(), &test_message)); google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN( CelValue accessed_value, CreateValueFromRepeatedField( &test_message, - test_message.GetDescriptor()->FindFieldByName(std::string(field_name())), 0, + test_message.GetDescriptor()->FindFieldByName(field_name()), 0, &CelProtoWrapper::InternalWrapMessage, &arena)); EXPECT_THAT(accessed_value, test::EqualsCelValue(cel_value())); @@ -380,7 +380,7 @@ TEST_P(RepeatedFieldTest, AppendElem) { google::protobuf::Arena arena; ASSERT_OK(AddValueToRepeatedField( - to_add, test_message.GetDescriptor()->FindFieldByName(std::string(field_name())), + to_add, test_message.GetDescriptor()->FindFieldByName(field_name()), &test_message, &arena)); EXPECT_THAT(test_message, EqualsProto(message_textproto())); diff --git a/eval/public/structs/proto_message_type_adapter.cc b/eval/public/structs/proto_message_type_adapter.cc index d5ecfe25c..42c7f10d9 100644 --- a/eval/public/structs/proto_message_type_adapter.cc +++ b/eval/public/structs/proto_message_type_adapter.cc @@ -86,7 +86,7 @@ absl::StatusOr HasFieldImpl(const google::protobuf::Message* message, absl::string_view field_name) { ABSL_ASSERT(descriptor == message->GetDescriptor()); const Reflection* reflection = message->GetReflection(); - const FieldDescriptor* field_desc = descriptor->FindFieldByName(std::string(field_name)); + const FieldDescriptor* field_desc = descriptor->FindFieldByName(field_name); if (field_desc == nullptr && reflection != nullptr) { // Search to see whether the field name is referring to an extension. field_desc = reflection->FindKnownExtensionByName(field_name); @@ -122,7 +122,7 @@ absl::StatusOr GetFieldImpl(const google::protobuf::Message* message, cel::MemoryManager& memory_manager) { ABSL_ASSERT(descriptor == message->GetDescriptor()); const Reflection* reflection = message->GetReflection(); - const FieldDescriptor* field_desc = descriptor->FindFieldByName(std::string(field_name)); + const FieldDescriptor* field_desc = descriptor->FindFieldByName(field_name); if (field_desc == nullptr && reflection != nullptr) { std::string ext_name(field_name); field_desc = reflection->FindKnownExtensionByName(ext_name); @@ -332,7 +332,7 @@ ProtoMessageTypeAdapter::NewInstance(cel::MemoryManager& memory_manager) const { } bool ProtoMessageTypeAdapter::DefinesField(absl::string_view field_name) const { - return descriptor_->FindFieldByName(std::string(field_name)) != nullptr; + return descriptor_->FindFieldByName(field_name) != nullptr; } absl::StatusOr ProtoMessageTypeAdapter::HasField( @@ -365,7 +365,7 @@ absl::Status ProtoMessageTypeAdapter::SetField( UnwrapMessage(instance, "SetField")); const google::protobuf::FieldDescriptor* field_descriptor = - descriptor_->FindFieldByName(std::string(field_name)); + descriptor_->FindFieldByName(field_name); CEL_RETURN_IF_ERROR( ValidateSetFieldOp(field_descriptor != nullptr, field_name, "not found")); diff --git a/eval/public/structs/protobuf_descriptor_type_provider.cc b/eval/public/structs/protobuf_descriptor_type_provider.cc index a2928aed3..03313d733 100644 --- a/eval/public/structs/protobuf_descriptor_type_provider.cc +++ b/eval/public/structs/protobuf_descriptor_type_provider.cc @@ -53,7 +53,7 @@ ProtobufDescriptorProvider::ProvideLegacyTypeInfo( std::unique_ptr ProtobufDescriptorProvider::GetType( absl::string_view name) const { const google::protobuf::Descriptor* descriptor = - descriptor_pool_->FindMessageTypeByName(std::string(name)); + descriptor_pool_->FindMessageTypeByName(name); if (descriptor == nullptr) { return nullptr; } diff --git a/extensions/protobuf/type_provider.cc b/extensions/protobuf/type_provider.cc index e639588ed..b09a8247d 100644 --- a/extensions/protobuf/type_provider.cc +++ b/extensions/protobuf/type_provider.cc @@ -23,7 +23,7 @@ namespace cel::extensions { absl::StatusOr>> ProtoTypeProvider::ProvideType( TypeFactory& type_factory, absl::string_view name) const { { - const auto* desc = pool_->FindMessageTypeByName(std::string(name)); + const auto* desc = pool_->FindMessageTypeByName(name); if (desc != nullptr) { return type_factory.CreateStructType(desc, factory_); } From 1032d847e7256121e8e2e76151a446161afed495 Mon Sep 17 00:00:00 2001 From: jcking Date: Thu, 18 May 2023 18:15:20 +0000 Subject: [PATCH 290/303] Refactor `ListValueBuilder` and `MapValueBuilder` ahead of introducing `NewValueBuilder` to `ListType` and `MapType` PiperOrigin-RevId: 533189606 --- base/BUILD | 1 + base/values/list_value_builder.h | 142 ++++--- base/values/map_value_builder.h | 547 ++++++++++++++------------ base/values/map_value_builder_test.cc | 367 +++++++++++++++-- 4 files changed, 723 insertions(+), 334 deletions(-) diff --git a/base/BUILD b/base/BUILD index 70e7e46e8..f3065cb83 100644 --- a/base/BUILD +++ b/base/BUILD @@ -208,6 +208,7 @@ cc_library( "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:variant", + "@com_google_absl//absl/utility", "@com_googlesource_code_re2//:re2", ], ) diff --git a/base/values/list_value_builder.h b/base/values/list_value_builder.h index ae4d9d7d6..768ce12e6 100644 --- a/base/values/list_value_builder.h +++ b/base/values/list_value_builder.h @@ -23,9 +23,12 @@ #include "absl/base/attributes.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/types/variant.h" +#include "absl/utility/utility.h" #include "base/memory.h" #include "base/value_factory.h" #include "base/values/list_value.h" +#include "internal/overloaded.h" namespace cel { @@ -75,6 +78,53 @@ class ListValueBuilder; namespace base_internal { +// ComposableListType is a variant which represents either the ListType or the +// element Type for creating a ListType. +template +using ComposableListType = absl::variant, Handle>; + +// Create a ListType from ComposableListType. +template +absl::StatusOr> ComposeListType( + ValueFactory& value_factory, ComposableListType&& composable) { + return absl::visit( + internal::Overloaded{ + [&value_factory]( + Handle&& element) -> absl::StatusOr> { + return value_factory.type_factory().CreateListType( + std::move(element)); + }, + [](Handle&& list) -> absl::StatusOr> { + return std::move(list); + }, + }, + std::move(composable)); +} + +template +std::string ComposeListValueDebugString(const List& list, + const DebugStringer& debug_stringer) { + std::string out; + out.push_back('['); + auto current = list.begin(); + if (current != list.end()) { + out.append(debug_stringer(*current)); + ++current; + for (; current != list.end(); ++current) { + out.append(", "); + out.append(debug_stringer(*current)); + } + } + out.push_back(']'); + return out; +} + +struct ComposedListType { + explicit ComposedListType() = default; +}; + +inline constexpr ComposedListType kComposedListType{}; + // Implementation of ListValueBuilder. Specialized to store some value types as // C++ primitives, avoiding Handle overhead. Anything that does not have a C++ // primitive is stored as Handle. @@ -92,22 +142,22 @@ class ListValueBuilderImpl : public ListValueBuilderInterface { ABSL_ATTRIBUTE_LIFETIME_BOUND ValueFactory& value_factory, Handle::type_type> type) : ListValueBuilderInterface(value_factory), - type_(std::move(type)), + type_(absl::in_place_type::type_type>>, + std::move(type)), + storage_(Allocator>{value_factory.memory_manager()}) {} + + ListValueBuilderImpl( + ComposedListType, + ABSL_ATTRIBUTE_LIFETIME_BOUND ValueFactory& value_factory, + Handle type) + : ListValueBuilderInterface(value_factory), + type_(absl::in_place_type>, std::move(type)), storage_(Allocator>{value_factory.memory_manager()}) {} std::string DebugString() const override { - size_t count = size(); - std::string out; - out.push_back('['); - if (count != 0) { - out.append(storage_[0]->DebugString()); - for (size_t index = 1; index < count; index++) { - out.append(", "); - out.append(storage_[index]->DebugString()); - } - } - out.push_back(']'); - return out; + return ComposeListValueDebugString( + storage_, + [](const Handle& value) { return value->DebugString(); }); } absl::Status Add(Handle value) override { @@ -127,14 +177,14 @@ class ListValueBuilderImpl : public ListValueBuilderInterface { absl::StatusOr> Build() && override { CEL_ASSIGN_OR_RETURN(auto type, - value_factory().type_factory().CreateListType(type_)); + ComposeListType(value_factory(), std::move(type_))); return value_factory() .template CreateListValue( std::move(type), std::move(storage_)); } private: - Handle::type_type> type_; + ComposableListType::type_type> type_; std::vector, Allocator>> storage_; }; @@ -147,22 +197,21 @@ class ListValueBuilderImpl : public ListValueBuilderInterface { ABSL_ATTRIBUTE_LIFETIME_BOUND ValueFactory& value_factory, Handle type) : ListValueBuilderInterface(value_factory), - type_(std::move(type)), + type_(absl::in_place_type>, std::move(type)), + storage_(Allocator>{value_factory.memory_manager()}) {} + + ListValueBuilderImpl( + ComposedListType, + ABSL_ATTRIBUTE_LIFETIME_BOUND ValueFactory& value_factory, + Handle type) + : ListValueBuilderInterface(value_factory), + type_(absl::in_place_type>, std::move(type)), storage_(Allocator>{value_factory.memory_manager()}) {} std::string DebugString() const override { - size_t count = size(); - std::string out; - out.push_back('['); - if (count != 0) { - out.append(storage_[0]->DebugString()); - for (size_t index = 1; index < count; index++) { - out.append(", "); - out.append(storage_[index]->DebugString()); - } - } - out.push_back(']'); - return out; + return ComposeListValueDebugString( + storage_, + [](const Handle& value) { return value->DebugString(); }); } absl::Status Add(Handle value) override { @@ -178,14 +227,14 @@ class ListValueBuilderImpl : public ListValueBuilderInterface { absl::StatusOr> Build() && override { CEL_ASSIGN_OR_RETURN(auto type, - value_factory().type_factory().CreateListType(type_)); + ComposeListType(value_factory(), std::move(type_))); return value_factory() .template CreateListValue( std::move(type), std::move(storage_)); } private: - Handle type_; + ComposableListType type_; std::vector, Allocator>> storage_; }; @@ -198,23 +247,22 @@ class ListValueBuilderImpl : public ListValueBuilderInterface { ABSL_ATTRIBUTE_LIFETIME_BOUND ValueFactory& value_factory, Handle::type_type> type) : ListValueBuilderInterface(value_factory), - type_(std::move(type)), + type_(absl::in_place_type::type_type>>, + std::move(type)), + storage_(Allocator{value_factory.memory_manager()}) {} + + ListValueBuilderImpl( + ComposedListType, + ABSL_ATTRIBUTE_LIFETIME_BOUND ValueFactory& value_factory, + Handle type) + : ListValueBuilderInterface(value_factory), + type_(absl::in_place_type>, std::move(type)), storage_(Allocator{value_factory.memory_manager()}) {} std::string DebugString() const override { - using value_traits = ValueTraits; - size_t count = size(); - std::string out; - out.push_back('['); - if (count != 0) { - out.append(value_traits::DebugString(storage_[0])); - for (size_t index = 1; index < count; index++) { - out.append(", "); - out.append(value_traits::DebugString(storage_[index])); - } - } - out.push_back(']'); - return out; + return ComposeListValueDebugString(storage_, [](const U& value) { + return ValueTraits::DebugString(value); + }); } absl::Status Add(Handle value) override { @@ -236,14 +284,14 @@ class ListValueBuilderImpl : public ListValueBuilderInterface { absl::StatusOr> Build() && override { CEL_ASSIGN_OR_RETURN(auto type, - value_factory().type_factory().CreateListType(type_)); + ComposeListType(value_factory(), std::move(type_))); return value_factory() .template CreateListValue>( std::move(type), std::move(storage_)); } private: - Handle::type_type> type_; + ComposableListType::type_type> type_; std::vector> storage_; }; @@ -257,6 +305,8 @@ class ListValueBuilder final using Impl = base_internal::ListValueBuilderImpl< T, typename base_internal::ValueTraits::underlying_type>; + static_assert(!std::is_same_v); + public: using Impl::Impl; }; diff --git a/base/values/map_value_builder.h b/base/values/map_value_builder.h index 931ad153d..79c06a7f7 100644 --- a/base/values/map_value_builder.h +++ b/base/values/map_value_builder.h @@ -22,11 +22,13 @@ #include "absl/base/attributes.h" #include "absl/base/macros.h" #include "absl/status/statusor.h" +#include "absl/types/variant.h" #include "base/memory.h" #include "base/value_factory.h" #include "base/values/list_value_builder.h" #include "base/values/map_value.h" #include "internal/linked_hash_map.h" +#include "internal/overloaded.h" #include "internal/status_macros.h" namespace cel { @@ -54,7 +56,7 @@ class MapValueBuilderInterface { // A combination of Insert and Update, where the entry is inserted if it // doesn't already exist or it is updated. Returns true if insertion occurred, // false otherwise. - virtual absl::StatusOr InsertOrUpdate(Handle key, + virtual absl::StatusOr InsertOrAssign(Handle key, Handle value) = 0; // Returns whether the given key has been inserted. @@ -159,6 +161,29 @@ struct MapKeyEqualer> { } }; +template +std::string ComposeMapValueDebugString( + const Map& map, const KeyDebugStringer& key_debug_stringer, + const ValueDebugStringer& value_debug_stringer) { + std::string out; + out.push_back('{'); + auto current = map.begin(); + if (current != map.end()) { + out.append(key_debug_stringer(current->first)); + out.append(": "); + out.append(value_debug_stringer(current->second)); + ++current; + for (; current != map.end(); ++current) { + out.append(", "); + out.append(key_debug_stringer(current->first)); + out.append(": "); + out.append(value_debug_stringer(current->second)); + } + } + out.push_back('}'); + return out; +} + // For MapValueBuilder we use a linked hash map to preserve insertion order. // This mimics protobuf and ensures some reproducibility, making testing easier. @@ -175,23 +200,10 @@ class DynamicMapValue final : public AbstractMapValue { : AbstractMapValue(std::move(type)), storage_(std::move(storage)) {} std::string DebugString() const override { - std::string out; - out.push_back('{'); - auto current = storage_.begin(); - if (current != storage_.end()) { - out.append(current->first->DebugString()); - out.append(": "); - out.append(current->second->DebugString()); - ++current; - for (; current != storage_.end(); ++current) { - out.append(", "); - out.append(current->first->DebugString()); - out.append(": "); - out.append(current->second->DebugString()); - } - } - out.push_back('}'); - return out; + return ComposeMapValueDebugString( + storage_, + [](const Handle& value) { return value->DebugString(); }, + [](const Handle& value) { return value->DebugString(); }); } size_t size() const override { return storage_.size(); } @@ -250,24 +262,12 @@ class StaticMapValue final : public AbstractMapValue { : AbstractMapValue(std::move(type)), storage_(std::move(storage)) {} std::string DebugString() const override { - using key_value_traits = ValueTraits; - std::string out; - out.push_back('{'); - auto current = storage_.begin(); - if (current != storage_.end()) { - out.append(key_value_traits::DebugString(current->first)); - out.append(": "); - out.append(current->second->DebugString()); - ++current; - for (; current != storage_.end(); ++current) { - out.append(", "); - out.append(key_value_traits::DebugString(current->first)); - out.append(": "); - out.append(current->second->DebugString()); - } - } - out.push_back('}'); - return out; + return ComposeMapValueDebugString( + storage_, + [](const underlying_key_type& value) { + return ValueTraits::DebugString(value); + }, + [](const Handle& value) { return value->DebugString(); }); } size_t size() const override { return storage_.size(); } @@ -290,8 +290,9 @@ class StaticMapValue final : public AbstractMapValue { absl::StatusOr> ListKeys( const ListKeysContext& context) const override { - ListValueBuilder keys(context.value_factory(), - type()->key().template As()); + ListValueBuilder keys( + context.value_factory(), + type()->key().template As::type_type>()); keys.reserve(size()); for (const auto& current : storage_) { CEL_RETURN_IF_ERROR(keys.Add(current.first)); @@ -323,24 +324,12 @@ class StaticMapValue final : public AbstractMapValue { : AbstractMapValue(std::move(type)), storage_(std::move(storage)) {} std::string DebugString() const override { - using value_value_traits = ValueTraits; - std::string out; - out.push_back('{'); - auto current = storage_.begin(); - if (current != storage_.end()) { - out.append(current->first->DebugString()); - out.append(": "); - out.append(value_value_traits::DebugString(current->second)); - ++current; - for (; current != storage_.end(); ++current) { - out.append(", "); - out.append(current->first->DebugString()); - out.append(": "); - out.append(value_value_traits::DebugString(current->second)); - } - } - out.push_back('}'); - return out; + return ComposeMapValueDebugString( + storage_, + [](const Handle& value) { return value->DebugString(); }, + [](const underlying_value_type& value) { + return ValueTraits::DebugString(value); + }); } size_t size() const override { return storage_.size(); } @@ -398,25 +387,14 @@ class StaticMapValue final : public AbstractMapValue { : AbstractMapValue(std::move(type)), storage_(std::move(storage)) {} std::string DebugString() const override { - using key_value_traits = ValueTraits; - using value_value_traits = ValueTraits; - std::string out; - out.push_back('{'); - auto current = storage_.begin(); - if (current != storage_.end()) { - out.append(key_value_traits::DebugString(current->first)); - out.append(": "); - out.append(value_value_traits::DebugString(current->second)); - ++current; - for (; current != storage_.end(); ++current) { - out.append(", "); - out.append(key_value_traits::DebugString(current->first)); - out.append(": "); - out.append(value_value_traits::DebugString(current->second)); - } - } - out.push_back('}'); - return out; + return ComposeMapValueDebugString( + storage_, + [](const underlying_key_type& value) { + return ValueTraits::DebugString(value); + }, + [](const underlying_value_type& value) { + return ValueTraits::DebugString(value); + }); } size_t size() const override { return storage_.size(); } @@ -457,6 +435,30 @@ class StaticMapValue final : public AbstractMapValue { hash_map_type storage_; }; +// ComposableMapType is a variant which represents either the MapType or the +// key and value Type for creating a MapType. +template +using ComposableMapType = + absl::variant, Handle>, Handle>; + +// Create a MapType from ComposableMapType. +template +absl::StatusOr> ComposeMapType( + ValueFactory& value_factory, ComposableMapType&& composable) { + return absl::visit( + internal::Overloaded{ + [&value_factory](std::pair, Handle>&& key_value) + -> absl::StatusOr> { + return value_factory.type_factory().CreateMapType( + std::move(key_value).first, std::move(key_value).second); + }, + [](Handle&& map) -> absl::StatusOr> { + return std::move(map); + }, + }, + std::move(composable)); +} + // Implementation of MapValueBuilder. Specialized to store some value types are // C++ primitives, avoiding Handle overhead. Anything that does not have a C++ // primitive is stored as Handle. @@ -475,29 +477,22 @@ class MapValueBuilderImpl : public MapValueBuilderInterface { Handle::type_type> key, Handle::type_type> value) : MapValueBuilderInterface(value_factory), - key_(std::move(key)), - value_(std::move(value)), + type_(std::make_pair(std::move(key), std::move(value))), + storage_(Allocator, Handle>>{ + value_factory.memory_manager()}) {} + + MapValueBuilderImpl(ABSL_ATTRIBUTE_LIFETIME_BOUND ValueFactory& value_factory, + Handle type) + : MapValueBuilderInterface(value_factory), + type_(std::move(type)), storage_(Allocator, Handle>>{ value_factory.memory_manager()}) {} std::string DebugString() const override { - std::string out; - out.push_back('{'); - auto current = storage_.begin(); - if (current != storage_.end()) { - out.append(current->first->DebugString()); - out.append(": "); - out.append(current->second->DebugString()); - ++current; - for (; current != storage_.end(); ++current) { - out.append(", "); - out.append(current->first->DebugString()); - out.append(": "); - out.append(current->second->DebugString()); - } - } - out.push_back('}'); - return out; + return ComposeMapValueDebugString( + storage_, + [](const Handle& value) { return value->DebugString(); }, + [](const Handle& value) { return value->DebugString(); }); } absl::StatusOr Insert(Handle key, Handle value) override { @@ -523,12 +518,12 @@ class MapValueBuilderImpl : public MapValueBuilderInterface { return true; } - absl::StatusOr InsertOrUpdate(Handle key, + absl::StatusOr InsertOrAssign(Handle key, Handle value) override { - return InsertOrUpdate(std::move(key).As(), std::move(value).As()); + return InsertOrAssign(std::move(key).As(), std::move(value).As()); } - absl::StatusOr InsertOrUpdate(Handle key, Handle value) { + absl::StatusOr InsertOrAssign(Handle key, Handle value) { return storage_.insert_or_assign(std::move(key), std::move(value)).second; } @@ -543,15 +538,16 @@ class MapValueBuilderImpl : public MapValueBuilderInterface { bool empty() const override { return storage_.empty(); } absl::StatusOr> Build() && override { - CEL_ASSIGN_OR_RETURN( - auto type, value_factory().type_factory().CreateMapType(key_, value_)); + CEL_ASSIGN_OR_RETURN(auto type, + ComposeMapType(value_factory(), std::move(type_))); return value_factory().template CreateMapValue( std::move(type), std::move(storage_)); } private: - Handle::type_type> key_; - Handle::type_type> value_; + ComposableMapType::type_type, + typename ValueTraits::type_type> + type_; internal::LinkedHashMap< Handle, Handle, MapKeyHasher>, MapKeyEqualer>, @@ -577,23 +573,10 @@ class MapValueBuilderImpl value_factory.memory_manager()}) {} std::string DebugString() const override { - std::string out; - out.push_back('{'); - auto current = storage_.begin(); - if (current != storage_.end()) { - out.append(current->first->DebugString()); - out.append(": "); - out.append(current->second->DebugString()); - ++current; - for (; current != storage_.end(); ++current) { - out.append(", "); - out.append(current->first->DebugString()); - out.append(": "); - out.append(current->second->DebugString()); - } - } - out.push_back('}'); - return out; + return ComposeMapValueDebugString( + storage_, + [](const Handle& value) { return value->DebugString(); }, + [](const Handle& value) { return value->DebugString(); }); } absl::StatusOr Insert(Handle key, Handle value) override { @@ -619,12 +602,12 @@ class MapValueBuilderImpl return true; } - absl::StatusOr InsertOrUpdate(Handle key, + absl::StatusOr InsertOrAssign(Handle key, Handle value) override { - return InsertOrUpdate(std::move(key).As(), std::move(value)); + return InsertOrAssign(std::move(key).As(), std::move(value)); } - absl::StatusOr InsertOrUpdate(Handle key, Handle value) { + absl::StatusOr InsertOrAssign(Handle key, Handle value) { return storage_.insert_or_assign(std::move(key), std::move(value)).second; } @@ -673,23 +656,10 @@ class MapValueBuilderImpl value_factory.memory_manager()}) {} std::string DebugString() const override { - std::string out; - out.push_back('{'); - auto current = storage_.begin(); - if (current != storage_.end()) { - out.append(current->first->DebugString()); - out.append(": "); - out.append(current->second->DebugString()); - ++current; - for (; current != storage_.end(); ++current) { - out.append(", "); - out.append(current->first->DebugString()); - out.append(": "); - out.append(current->second->DebugString()); - } - } - out.push_back('}'); - return out; + return ComposeMapValueDebugString( + storage_, + [](const Handle& value) { return value->DebugString(); }, + [](const Handle& value) { return value->DebugString(); }); } absl::StatusOr Insert(Handle key, Handle value) override { @@ -715,12 +685,12 @@ class MapValueBuilderImpl return true; } - absl::StatusOr InsertOrUpdate(Handle key, + absl::StatusOr InsertOrAssign(Handle key, Handle value) override { - return InsertOrUpdate(std::move(key), std::move(value).As()); + return InsertOrAssign(std::move(key), std::move(value).As()); } - absl::StatusOr InsertOrUpdate(Handle key, Handle value) { + absl::StatusOr InsertOrAssign(Handle key, Handle value) { return storage_.insert_or_assign(std::move(key), std::move(value)).second; } @@ -749,6 +719,103 @@ class MapValueBuilderImpl storage_; }; +// Specialization for key type being Value itself and value type has some C++ +// primitive representation. +template +class MapValueBuilderImpl + : public MapValueBuilderInterface { + public: + static_assert(std::is_base_of_v); + + MapValueBuilderImpl(ABSL_ATTRIBUTE_LIFETIME_BOUND ValueFactory& value_factory, + Handle key, + Handle::type_type> value) + : MapValueBuilderInterface(value_factory), + type_(std::make_pair(std::move(key), std::move(value))), + storage_(Allocator, Handle>>{ + value_factory.memory_manager()}) {} + + MapValueBuilderImpl(ABSL_ATTRIBUTE_LIFETIME_BOUND ValueFactory& value_factory, + Handle type) + : MapValueBuilderInterface(value_factory), + type_(std::move(type)), + storage_(Allocator, Handle>>{ + value_factory.memory_manager()}) {} + + std::string DebugString() const override { + return ComposeMapValueDebugString( + storage_, + [](const Handle& value) { return value->DebugString(); }, + [](const UV& value) { return ValueTraits::DebugString(value); }); + } + + absl::StatusOr Insert(Handle key, Handle value) override { + return Insert(std::move(key), std::move(value).As()); + } + + absl::StatusOr Insert(Handle key, Handle value) { + return Insert(std::move(key), value->value()); + } + + absl::StatusOr Insert(Handle key, UV value) { + return storage_.insert(std::make_pair(std::move(key), std::move(value))) + .second; + } + + absl::StatusOr Update(const Handle& key, + Handle value) override { + return Update(key, std::move(value).As()); + } + + absl::StatusOr Update(const Handle& key, Handle value) { + return Update(key, value->value()); + } + + absl::StatusOr Update(const Handle& key, UV value) { + auto existing = storage_.find(key); + if (existing == storage_.end()) { + return false; + } + existing->second = std::move(value); + return true; + } + + absl::StatusOr InsertOrAssign(Handle key, + Handle value) override { + return InsertOrAssign(std::move(key), std::move(value).As()); + } + + absl::StatusOr InsertOrAssign(Handle key, Handle value) { + return InsertOrAssign(std::move(key), value->value()); + } + + absl::StatusOr InsertOrAssign(Handle key, UV value) { + return storage_.insert_or_assign(std::move(key), std::move(value)).second; + } + + bool Has(const Handle& key) const override { + return storage_.find(key) != storage_.end(); + } + + size_t size() const override { return storage_.size(); } + + bool empty() const override { return storage_.empty(); } + + absl::StatusOr> Build() && override { + CEL_ASSIGN_OR_RETURN(auto type, + ComposeMapType(value_factory(), std::move(type_))); + return value_factory().template CreateMapValue>( + std::move(type), std::move(storage_)); + } + + private: + ComposableMapType::type_type> type_; + internal::LinkedHashMap, UV, MapKeyHasher>, + MapKeyEqualer>, + Allocator, UV>>> + storage_; +}; + // Specialization for key type and value type being Value itself. template <> class MapValueBuilderImpl @@ -757,29 +824,22 @@ class MapValueBuilderImpl MapValueBuilderImpl(ABSL_ATTRIBUTE_LIFETIME_BOUND ValueFactory& value_factory, Handle key, Handle value) : MapValueBuilderInterface(value_factory), - key_(std::move(key)), - value_(std::move(value)), + type_(std::make_pair(std::move(key), std::move(value))), + storage_(Allocator, Handle>>{ + value_factory.memory_manager()}) {} + + MapValueBuilderImpl(ABSL_ATTRIBUTE_LIFETIME_BOUND ValueFactory& value_factory, + Handle type) + : MapValueBuilderInterface(value_factory), + type_(std::move(type)), storage_(Allocator, Handle>>{ value_factory.memory_manager()}) {} std::string DebugString() const override { - std::string out; - out.push_back('{'); - auto current = storage_.begin(); - if (current != storage_.end()) { - out.append(current->first->DebugString()); - out.append(": "); - out.append(current->second->DebugString()); - ++current; - for (; current != storage_.end(); ++current) { - out.append(", "); - out.append(current->first->DebugString()); - out.append(": "); - out.append(current->second->DebugString()); - } - } - out.push_back('}'); - return out; + return ComposeMapValueDebugString( + storage_, + [](const Handle& value) { return value->DebugString(); }, + [](const Handle& value) { return value->DebugString(); }); } absl::StatusOr Insert(Handle key, Handle value) override { @@ -797,7 +857,7 @@ class MapValueBuilderImpl return true; } - absl::StatusOr InsertOrUpdate(Handle key, + absl::StatusOr InsertOrAssign(Handle key, Handle value) override { return storage_.insert_or_assign(std::move(key), std::move(value)).second; } @@ -811,15 +871,14 @@ class MapValueBuilderImpl bool empty() const override { return storage_.empty(); } absl::StatusOr> Build() && override { - CEL_ASSIGN_OR_RETURN( - auto type, value_factory().type_factory().CreateMapType(key_, value_)); + CEL_ASSIGN_OR_RETURN(auto type, + ComposeMapType(value_factory(), std::move(type_))); return value_factory().template CreateMapValue( std::move(type), std::move(storage_)); } private: - Handle key_; - Handle value_; + ComposableMapType type_; internal::LinkedHashMap< Handle, Handle, MapKeyHasher>, MapKeyEqualer>, @@ -840,29 +899,22 @@ class MapValueBuilderImpl : public MapValueBuilderInterface { Handle::type_type> key, Handle::type_type> value) : MapValueBuilderInterface(value_factory), - key_(std::move(key)), - value_(std::move(value)), + type_(std::make_pair(std::move(key), std::move(value))), + storage_(Allocator>>{ + value_factory.memory_manager()}) {} + + MapValueBuilderImpl(ABSL_ATTRIBUTE_LIFETIME_BOUND ValueFactory& value_factory, + Handle type) + : MapValueBuilderInterface(value_factory), + type_(std::move(type)), storage_(Allocator>>{ value_factory.memory_manager()}) {} std::string DebugString() const override { - std::string out; - out.push_back('{'); - auto current = storage_.begin(); - if (current != storage_.end()) { - out.append(ValueTraits::DebugString(current->first)); - out.append(": "); - out.append(current->second->DebugString()); - ++current; - for (; current != storage_.end(); ++current) { - out.append(", "); - out.append(ValueTraits::DebugString(current->first)); - out.append(": "); - out.append(current->second->DebugString()); - } - } - out.push_back('}'); - return out; + return ComposeMapValueDebugString( + storage_, + [](const UK& value) { return ValueTraits::DebugString(value); }, + [](const Handle& value) { return value->DebugString(); }); } absl::StatusOr Insert(Handle key, Handle value) override { @@ -896,16 +948,16 @@ class MapValueBuilderImpl : public MapValueBuilderInterface { return true; } - absl::StatusOr InsertOrUpdate(Handle key, + absl::StatusOr InsertOrAssign(Handle key, Handle value) override { - return InsertOrUpdate(std::move(key).As(), std::move(value).As()); + return InsertOrAssign(std::move(key).As(), std::move(value).As()); } - absl::StatusOr InsertOrUpdate(Handle key, Handle value) { - return InsertOrUpdate(key->value(), std::move(value)); + absl::StatusOr InsertOrAssign(Handle key, Handle value) { + return InsertOrAssign(key->value(), std::move(value)); } - absl::StatusOr InsertOrUpdate(UK key, Handle value) { + absl::StatusOr InsertOrAssign(UK key, Handle value) { return storage_.insert_or_assign(std::move(key), std::move(value)).second; } @@ -920,15 +972,16 @@ class MapValueBuilderImpl : public MapValueBuilderInterface { bool empty() const override { return storage_.empty(); } absl::StatusOr> Build() && override { - CEL_ASSIGN_OR_RETURN( - auto type, value_factory().type_factory().CreateMapType(key_, value_)); + CEL_ASSIGN_OR_RETURN(auto type, + ComposeMapType(value_factory(), std::move(type_))); return value_factory().template CreateMapValue>( std::move(type), std::move(storage_)); } private: - Handle::type_type> key_; - Handle::type_type> value_; + ComposableMapType::type_type, + typename ValueTraits::type_type> + type_; internal::LinkedHashMap, MapKeyHasher, MapKeyEqualer, Allocator>>> @@ -948,29 +1001,22 @@ class MapValueBuilderImpl : public MapValueBuilderInterface { Handle::type_type> key, Handle::type_type> value) : MapValueBuilderInterface(value_factory), - key_(std::move(key)), - value_(std::move(value)), + type_(std::make_pair(std::move(key), std::move(value))), + storage_(Allocator, UV>>{ + value_factory.memory_manager()}) {} + + MapValueBuilderImpl(ABSL_ATTRIBUTE_LIFETIME_BOUND ValueFactory& value_factory, + Handle type) + : MapValueBuilderInterface(value_factory), + type_(std::move(type)), storage_(Allocator, UV>>{ value_factory.memory_manager()}) {} std::string DebugString() const override { - std::string out; - out.push_back('{'); - auto current = storage_.begin(); - if (current != storage_.end()) { - out.append(current->first->DebugString()); - out.append(": "); - out.append(ValueTraits::DebugString(current->second)); - ++current; - for (; current != storage_.end(); ++current) { - out.append(", "); - out.append(current->first->DebugString()); - out.append(": "); - out.append(ValueTraits::DebugString(current->second)); - } - } - out.push_back('}'); - return out; + return ComposeMapValueDebugString( + storage_, + [](const Handle& value) { return value->DebugString(); }, + [](const UV& value) { return ValueTraits::DebugString(value); }); } absl::StatusOr Insert(Handle key, Handle value) override { @@ -1004,16 +1050,16 @@ class MapValueBuilderImpl : public MapValueBuilderInterface { return true; } - absl::StatusOr InsertOrUpdate(Handle key, + absl::StatusOr InsertOrAssign(Handle key, Handle value) override { - return InsertOrUpdate(std::move(key).As(), std::move(value).As()); + return InsertOrAssign(std::move(key).As(), std::move(value).As()); } - absl::StatusOr InsertOrUpdate(Handle key, Handle value) { - return InsertOrUpdate(std::move(key), value->value()); + absl::StatusOr InsertOrAssign(Handle key, Handle value) { + return InsertOrAssign(std::move(key), value->value()); } - absl::StatusOr InsertOrUpdate(Handle key, UV value) { + absl::StatusOr InsertOrAssign(Handle key, UV value) { return storage_.insert_or_assign(std::move(key), std::move(value)).second; } @@ -1028,15 +1074,16 @@ class MapValueBuilderImpl : public MapValueBuilderInterface { bool empty() const override { return storage_.empty(); } absl::StatusOr> Build() && override { - CEL_ASSIGN_OR_RETURN( - auto type, value_factory().type_factory().CreateMapType(key_, value_)); + CEL_ASSIGN_OR_RETURN(auto type, + ComposeMapType(value_factory(), std::move(type_))); return value_factory().template CreateMapValue>( std::move(type), std::move(storage_)); } private: - Handle::type_type> key_; - Handle::type_type> value_; + ComposableMapType::type_type, + typename ValueTraits::type_type> + type_; internal::LinkedHashMap, UV, MapKeyHasher>, MapKeyEqualer>, Allocator, UV>>> @@ -1057,29 +1104,22 @@ class MapValueBuilderImpl : public MapValueBuilderInterface { Handle::type_type> key, Handle::type_type> value) : MapValueBuilderInterface(value_factory), - key_(std::move(key)), - value_(std::move(value)), + type_(std::make_pair(std::move(key), std::move(value))), + storage_(Allocator>{ + value_factory.memory_manager()}) {} + + MapValueBuilderImpl(ABSL_ATTRIBUTE_LIFETIME_BOUND ValueFactory& value_factory, + Handle type) + : MapValueBuilderInterface(value_factory), + type_(std::move(type)), storage_(Allocator>{ value_factory.memory_manager()}) {} std::string DebugString() const override { - std::string out; - out.push_back('{'); - auto current = storage_.begin(); - if (current != storage_.end()) { - out.append(ValueTraits::DebugString(current->first)); - out.append(": "); - out.append(ValueTraits::DebugString(current->second)); - ++current; - for (; current != storage_.end(); ++current) { - out.append(", "); - out.append(ValueTraits::DebugString(current->first)); - out.append(": "); - out.append(ValueTraits::DebugString(current->second)); - } - } - out.push_back('}'); - return out; + return ComposeMapValueDebugString( + storage_, + [](const UK& value) { return ValueTraits::DebugString(value); }, + [](const UV& value) { return ValueTraits::DebugString(value); }); } absl::StatusOr Insert(Handle key, Handle value) override { @@ -1129,24 +1169,24 @@ class MapValueBuilderImpl : public MapValueBuilderInterface { return true; } - absl::StatusOr InsertOrUpdate(Handle key, + absl::StatusOr InsertOrAssign(Handle key, Handle value) override { - return InsertOrUpdate(std::move(key).As(), std::move(value).As()); + return InsertOrAssign(std::move(key).As(), std::move(value).As()); } - absl::StatusOr InsertOrUpdate(Handle key, Handle value) { - return InsertOrUpdate(key->value(), value->value()); + absl::StatusOr InsertOrAssign(Handle key, Handle value) { + return InsertOrAssign(key->value(), value->value()); } - absl::StatusOr InsertOrUpdate(Handle key, UV value) { - return InsertOrUpdate(key->value(), std::move(value)); + absl::StatusOr InsertOrAssign(Handle key, UV value) { + return InsertOrAssign(key->value(), std::move(value)); } - absl::StatusOr InsertOrUpdate(UK key, Handle value) { - return InsertOrUpdate(std::move(key), value->value()); + absl::StatusOr InsertOrAssign(UK key, Handle value) { + return InsertOrAssign(std::move(key), value->value()); } - absl::StatusOr InsertOrUpdate(UK key, UV value) { + absl::StatusOr InsertOrAssign(UK key, UV value) { return storage_.insert_or_assign(std::move(key), std::move(value)).second; } @@ -1161,15 +1201,16 @@ class MapValueBuilderImpl : public MapValueBuilderInterface { bool empty() const override { return storage_.empty(); } absl::StatusOr> Build() && override { - CEL_ASSIGN_OR_RETURN( - auto type, value_factory().type_factory().CreateMapType(key_, value_)); + CEL_ASSIGN_OR_RETURN(auto type, + ComposeMapType(value_factory(), std::move(type_))); return value_factory().template CreateMapValue>( std::move(type), std::move(storage_)); } private: - Handle::type_type> key_; - Handle::type_type> value_; + ComposableMapType::type_type, + typename ValueTraits::type_type> + type_; internal::LinkedHashMap, MapKeyEqualer, Allocator>> storage_; diff --git a/base/values/map_value_builder_test.cc b/base/values/map_value_builder_test.cc index a325a1655..2d077b9e4 100644 --- a/base/values/map_value_builder_test.cc +++ b/base/values/map_value_builder_test.cc @@ -59,13 +59,13 @@ TEST(MapValueBuilder, UnspecializedUnspecialized) { IsOkAndHolds(IsTrue())); // lvalue, lvalue EXPECT_THAT(map_builder.Update(key, make_value()), IsOkAndHolds(IsTrue())); // lvalue, rvalue - EXPECT_THAT(map_builder.InsertOrUpdate(key, value), + EXPECT_THAT(map_builder.InsertOrAssign(key, value), IsOkAndHolds(IsFalse())); // lvalue, lvalue - EXPECT_THAT(map_builder.InsertOrUpdate(key, make_value()), + EXPECT_THAT(map_builder.InsertOrAssign(key, make_value()), IsOkAndHolds(IsFalse())); // lvalue, rvalue - EXPECT_THAT(map_builder.InsertOrUpdate(make_key("bar"), value), + EXPECT_THAT(map_builder.InsertOrAssign(make_key("bar"), value), IsOkAndHolds(IsTrue())); // rvalue, lvalue - EXPECT_THAT(map_builder.InsertOrUpdate(make_key("bar"), make_value("baz")), + EXPECT_THAT(map_builder.InsertOrAssign(make_key("bar"), make_value("baz")), IsOkAndHolds(IsFalse())); // rvalue, rvalue EXPECT_TRUE(map_builder.Has(key)); EXPECT_FALSE(map_builder.empty()); @@ -89,6 +89,11 @@ TEST(MapValueBuilder, UnspecializedUnspecialized) { EXPECT_EQ((*entry).As()->ToString(), "baz"); EXPECT_EQ(map->DebugString(), "{\"\": b\"\", \"foo\": b\"\", \"bar\": b\"baz\"}"); + ASSERT_OK_AND_ASSIGN(auto keys, + map->ListKeys(MapValue::ListKeysContext(value_factory))); + EXPECT_FALSE(keys->empty()); + EXPECT_EQ(keys->size(), 3); + EXPECT_EQ(keys->DebugString(), "[\"\", \"foo\", \"bar\"]"); } TEST(MapValueBuilder, UnspecializedGeneric) { @@ -123,13 +128,13 @@ TEST(MapValueBuilder, UnspecializedGeneric) { IsOkAndHolds(IsTrue())); // lvalue, lvalue EXPECT_THAT(map_builder.Update(key, make_value()), IsOkAndHolds(IsTrue())); // lvalue, rvalue - EXPECT_THAT(map_builder.InsertOrUpdate(key, value), + EXPECT_THAT(map_builder.InsertOrAssign(key, value), IsOkAndHolds(IsFalse())); // lvalue, lvalue - EXPECT_THAT(map_builder.InsertOrUpdate(key, make_value()), + EXPECT_THAT(map_builder.InsertOrAssign(key, make_value()), IsOkAndHolds(IsFalse())); // lvalue, rvalue - EXPECT_THAT(map_builder.InsertOrUpdate(make_key("bar"), value), + EXPECT_THAT(map_builder.InsertOrAssign(make_key("bar"), value), IsOkAndHolds(IsTrue())); // rvalue, lvalue - EXPECT_THAT(map_builder.InsertOrUpdate(make_key("bar"), make_value("baz")), + EXPECT_THAT(map_builder.InsertOrAssign(make_key("bar"), make_value("baz")), IsOkAndHolds(IsFalse())); // rvalue, rvalue EXPECT_TRUE(map_builder.Has(key)); EXPECT_FALSE(map_builder.empty()); @@ -153,6 +158,11 @@ TEST(MapValueBuilder, UnspecializedGeneric) { EXPECT_EQ((*entry).As()->ToString(), "baz"); EXPECT_EQ(map->DebugString(), "{\"\": b\"\", \"foo\": b\"\", \"bar\": b\"baz\"}"); + ASSERT_OK_AND_ASSIGN(auto keys, + map->ListKeys(MapValue::ListKeysContext(value_factory))); + EXPECT_FALSE(keys->empty()); + EXPECT_EQ(keys->size(), 3); + EXPECT_EQ(keys->DebugString(), "[\"\", \"foo\", \"bar\"]"); } TEST(MapValueBuilder, GenericUnspecialized) { @@ -187,13 +197,13 @@ TEST(MapValueBuilder, GenericUnspecialized) { IsOkAndHolds(IsTrue())); // lvalue, lvalue EXPECT_THAT(map_builder.Update(key, make_value()), IsOkAndHolds(IsTrue())); // lvalue, rvalue - EXPECT_THAT(map_builder.InsertOrUpdate(key, value), + EXPECT_THAT(map_builder.InsertOrAssign(key, value), IsOkAndHolds(IsFalse())); // lvalue, lvalue - EXPECT_THAT(map_builder.InsertOrUpdate(key, make_value()), + EXPECT_THAT(map_builder.InsertOrAssign(key, make_value()), IsOkAndHolds(IsFalse())); // lvalue, rvalue - EXPECT_THAT(map_builder.InsertOrUpdate(make_key("bar"), value), + EXPECT_THAT(map_builder.InsertOrAssign(make_key("bar"), value), IsOkAndHolds(IsTrue())); // rvalue, lvalue - EXPECT_THAT(map_builder.InsertOrUpdate(make_key("bar"), make_value("baz")), + EXPECT_THAT(map_builder.InsertOrAssign(make_key("bar"), make_value("baz")), IsOkAndHolds(IsFalse())); // rvalue, rvalue EXPECT_TRUE(map_builder.Has(key)); EXPECT_FALSE(map_builder.empty()); @@ -217,6 +227,11 @@ TEST(MapValueBuilder, GenericUnspecialized) { EXPECT_EQ((*entry).As()->ToString(), "baz"); EXPECT_EQ(map->DebugString(), "{\"\": b\"\", \"foo\": b\"\", \"bar\": b\"baz\"}"); + ASSERT_OK_AND_ASSIGN(auto keys, + map->ListKeys(MapValue::ListKeysContext(value_factory))); + EXPECT_FALSE(keys->empty()); + EXPECT_EQ(keys->size(), 3); + EXPECT_EQ(keys->DebugString(), "[\"\", \"foo\", \"bar\"]"); } TEST(MapValueBuilder, GenericGeneric) { @@ -251,13 +266,13 @@ TEST(MapValueBuilder, GenericGeneric) { IsOkAndHolds(IsTrue())); // lvalue, lvalue EXPECT_THAT(map_builder.Update(key, make_value()), IsOkAndHolds(IsTrue())); // lvalue, rvalue - EXPECT_THAT(map_builder.InsertOrUpdate(key, value), + EXPECT_THAT(map_builder.InsertOrAssign(key, value), IsOkAndHolds(IsFalse())); // lvalue, lvalue - EXPECT_THAT(map_builder.InsertOrUpdate(key, make_value()), + EXPECT_THAT(map_builder.InsertOrAssign(key, make_value()), IsOkAndHolds(IsFalse())); // lvalue, rvalue - EXPECT_THAT(map_builder.InsertOrUpdate(make_key("bar"), value), + EXPECT_THAT(map_builder.InsertOrAssign(make_key("bar"), value), IsOkAndHolds(IsTrue())); // rvalue, lvalue - EXPECT_THAT(map_builder.InsertOrUpdate(make_key("bar"), make_value("baz")), + EXPECT_THAT(map_builder.InsertOrAssign(make_key("bar"), make_value("baz")), IsOkAndHolds(IsFalse())); // rvalue, rvalue EXPECT_TRUE(map_builder.Has(key)); EXPECT_FALSE(map_builder.empty()); @@ -281,13 +296,278 @@ TEST(MapValueBuilder, GenericGeneric) { EXPECT_EQ((*entry).As()->ToString(), "baz"); EXPECT_EQ(map->DebugString(), "{\"\": b\"\", \"foo\": b\"\", \"bar\": b\"baz\"}"); + ASSERT_OK_AND_ASSIGN(auto keys, + map->ListKeys(MapValue::ListKeysContext(value_factory))); + EXPECT_FALSE(keys->empty()); + EXPECT_EQ(keys->size(), 3); + EXPECT_EQ(keys->DebugString(), "[\"\", \"foo\", \"bar\"]"); +} + +TEST(MapValueBuilder, UnspecializedSpecialized) { + TypeFactory type_factory(MemoryManager::Global()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); + auto map_builder = MapValueBuilder( + value_factory, type_factory.GetStringType(), type_factory.GetIntType()); + auto make_key = [&](absl::string_view value = + absl::string_view()) -> Handle { + return value_factory.CreateStringValue(value).value(); + }; + auto key = make_key(); + auto make_value = [&](int64_t value = 0) -> Handle { + return value_factory.CreateIntValue(value); + }; + auto value = make_value(); + EXPECT_THAT(map_builder.Update(key, value), + IsOkAndHolds(IsFalse())); // lvalue, lvalue + EXPECT_THAT(map_builder.Update(key, make_value()), + IsOkAndHolds(IsFalse())); // lvalue, rvalue + EXPECT_THAT(map_builder.Insert(key, value), + IsOkAndHolds(IsTrue())); // lvalue, lvalue + EXPECT_THAT(map_builder.Insert(key, make_value()), + IsOkAndHolds(IsFalse())); // lvalue, rvalue + EXPECT_THAT(map_builder.Insert(make_key("foo"), value), + IsOkAndHolds(IsTrue())); // rvalue, lvalue + EXPECT_THAT(map_builder.Insert(make_key("foo"), make_value()), + IsOkAndHolds(IsFalse())); // rvalue, rvalue + EXPECT_THAT(map_builder.Update(key, value), + IsOkAndHolds(IsTrue())); // lvalue, lvalue + EXPECT_THAT(map_builder.Update(key, make_value()), + IsOkAndHolds(IsTrue())); // lvalue, rvalue + EXPECT_THAT(map_builder.InsertOrAssign(key, value), + IsOkAndHolds(IsFalse())); // lvalue, lvalue + EXPECT_THAT(map_builder.InsertOrAssign(key, make_value()), + IsOkAndHolds(IsFalse())); // lvalue, rvalue + EXPECT_THAT(map_builder.InsertOrAssign(make_key("bar"), value), + IsOkAndHolds(IsTrue())); // rvalue, lvalue + EXPECT_THAT(map_builder.InsertOrAssign(make_key("bar"), make_value(1)), + IsOkAndHolds(IsFalse())); // rvalue, rvalue + EXPECT_TRUE(map_builder.Has(key)); + EXPECT_FALSE(map_builder.empty()); + EXPECT_EQ(map_builder.size(), 3); + EXPECT_EQ(map_builder.DebugString(), "{\"\": 0, \"foo\": 0, \"bar\": 1}"); + ASSERT_OK_AND_ASSIGN(auto map, std::move(map_builder).Build()); + EXPECT_FALSE(map->empty()); + EXPECT_EQ(map->size(), 3); + ASSERT_OK_AND_ASSIGN(auto entry, + map->Get(MapValue::GetContext(value_factory), key)); + EXPECT_TRUE((*entry)->Is()); + EXPECT_EQ((*entry).As()->value(), 0); + ASSERT_OK_AND_ASSIGN( + entry, map->Get(MapValue::GetContext(value_factory), make_key("foo"))); + EXPECT_TRUE((*entry)->Is()); + EXPECT_EQ((*entry).As()->value(), 0); + ASSERT_OK_AND_ASSIGN( + entry, map->Get(MapValue::GetContext(value_factory), make_key("bar"))); + EXPECT_TRUE((*entry)->Is()); + EXPECT_EQ((*entry).As()->value(), 1); + EXPECT_EQ(map->DebugString(), "{\"\": 0, \"foo\": 0, \"bar\": 1}"); + ASSERT_OK_AND_ASSIGN(auto keys, + map->ListKeys(MapValue::ListKeysContext(value_factory))); + EXPECT_FALSE(keys->empty()); + EXPECT_EQ(keys->size(), 3); + EXPECT_EQ(keys->DebugString(), "[\"\", \"foo\", \"bar\"]"); +} + +TEST(MapValueBuilder, GenericSpecialized) { + TypeFactory type_factory(MemoryManager::Global()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); + auto map_builder = MapValueBuilder( + value_factory, type_factory.GetStringType(), type_factory.GetIntType()); + auto make_key = [&](absl::string_view value = + absl::string_view()) -> Handle { + return value_factory.CreateStringValue(value).value(); + }; + auto key = make_key(); + auto make_value = [&](int64_t value = 0) -> Handle { + return value_factory.CreateIntValue(value); + }; + auto value = make_value(); + EXPECT_THAT(map_builder.Update(key, value), + IsOkAndHolds(IsFalse())); // lvalue, lvalue + EXPECT_THAT(map_builder.Update(key, make_value()), + IsOkAndHolds(IsFalse())); // lvalue, rvalue + EXPECT_THAT(map_builder.Insert(key, value), + IsOkAndHolds(IsTrue())); // lvalue, lvalue + EXPECT_THAT(map_builder.Insert(key, make_value()), + IsOkAndHolds(IsFalse())); // lvalue, rvalue + EXPECT_THAT(map_builder.Insert(make_key("foo"), value), + IsOkAndHolds(IsTrue())); // rvalue, lvalue + EXPECT_THAT(map_builder.Insert(make_key("foo"), make_value()), + IsOkAndHolds(IsFalse())); // rvalue, rvalue + EXPECT_THAT(map_builder.Update(key, value), + IsOkAndHolds(IsTrue())); // lvalue, lvalue + EXPECT_THAT(map_builder.Update(key, make_value()), + IsOkAndHolds(IsTrue())); // lvalue, rvalue + EXPECT_THAT(map_builder.InsertOrAssign(key, value), + IsOkAndHolds(IsFalse())); // lvalue, lvalue + EXPECT_THAT(map_builder.InsertOrAssign(key, make_value()), + IsOkAndHolds(IsFalse())); // lvalue, rvalue + EXPECT_THAT(map_builder.InsertOrAssign(make_key("bar"), value), + IsOkAndHolds(IsTrue())); // rvalue, lvalue + EXPECT_THAT(map_builder.InsertOrAssign(make_key("bar"), make_value(1)), + IsOkAndHolds(IsFalse())); // rvalue, rvalue + EXPECT_TRUE(map_builder.Has(key)); + EXPECT_FALSE(map_builder.empty()); + EXPECT_EQ(map_builder.size(), 3); + EXPECT_EQ(map_builder.DebugString(), "{\"\": 0, \"foo\": 0, \"bar\": 1}"); + ASSERT_OK_AND_ASSIGN(auto map, std::move(map_builder).Build()); + EXPECT_FALSE(map->empty()); + EXPECT_EQ(map->size(), 3); + ASSERT_OK_AND_ASSIGN(auto entry, + map->Get(MapValue::GetContext(value_factory), key)); + EXPECT_TRUE((*entry)->Is()); + EXPECT_EQ((*entry).As()->value(), 0); + ASSERT_OK_AND_ASSIGN( + entry, map->Get(MapValue::GetContext(value_factory), make_key("foo"))); + EXPECT_TRUE((*entry)->Is()); + EXPECT_EQ((*entry).As()->value(), 0); + ASSERT_OK_AND_ASSIGN( + entry, map->Get(MapValue::GetContext(value_factory), make_key("bar"))); + EXPECT_TRUE((*entry)->Is()); + EXPECT_EQ((*entry).As()->value(), 1); + EXPECT_EQ(map->DebugString(), "{\"\": 0, \"foo\": 0, \"bar\": 1}"); + ASSERT_OK_AND_ASSIGN(auto keys, + map->ListKeys(MapValue::ListKeysContext(value_factory))); + EXPECT_FALSE(keys->empty()); + EXPECT_EQ(keys->size(), 3); + EXPECT_EQ(keys->DebugString(), "[\"\", \"foo\", \"bar\"]"); +} + +TEST(MapValueBuilder, SpecializedUnspecialized) { + TypeFactory type_factory(MemoryManager::Global()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); + auto map_builder = MapValueBuilder( + value_factory, type_factory.GetIntType(), type_factory.GetStringType()); + auto make_key = [&](int64_t value = 0) -> Handle { + return value_factory.CreateIntValue(value); + }; + auto key = make_key(); + auto make_value = [&](absl::string_view value = + absl::string_view()) -> Handle { + return value_factory.CreateStringValue(value).value(); + }; + auto value = make_value(); + EXPECT_THAT(map_builder.Update(key, value), + IsOkAndHolds(IsFalse())); // lvalue, lvalue + EXPECT_THAT(map_builder.Update(key, make_value()), + IsOkAndHolds(IsFalse())); // lvalue, rvalue + EXPECT_THAT(map_builder.Insert(key, value), + IsOkAndHolds(IsTrue())); // lvalue, lvalue + EXPECT_THAT(map_builder.Insert(key, make_value()), + IsOkAndHolds(IsFalse())); // lvalue, rvalue + EXPECT_THAT(map_builder.Insert(make_key(1), value), + IsOkAndHolds(IsTrue())); // rvalue, lvalue + EXPECT_THAT(map_builder.Insert(make_key(1), make_value()), + IsOkAndHolds(IsFalse())); // rvalue, rvalue + EXPECT_THAT(map_builder.Update(key, value), + IsOkAndHolds(IsTrue())); // lvalue, lvalue + EXPECT_THAT(map_builder.Update(key, make_value()), + IsOkAndHolds(IsTrue())); // lvalue, rvalue + EXPECT_THAT(map_builder.InsertOrAssign(key, value), + IsOkAndHolds(IsFalse())); // lvalue, lvalue + EXPECT_THAT(map_builder.InsertOrAssign(key, make_value()), + IsOkAndHolds(IsFalse())); // lvalue, rvalue + EXPECT_THAT(map_builder.InsertOrAssign(make_key(2), value), + IsOkAndHolds(IsTrue())); // rvalue, lvalue + EXPECT_THAT(map_builder.InsertOrAssign(make_key(2), make_value("foo")), + IsOkAndHolds(IsFalse())); // rvalue, rvalue + EXPECT_TRUE(map_builder.Has(key)); + EXPECT_FALSE(map_builder.empty()); + EXPECT_EQ(map_builder.size(), 3); + EXPECT_EQ(map_builder.DebugString(), "{0: \"\", 1: \"\", 2: \"foo\"}"); + ASSERT_OK_AND_ASSIGN(auto map, std::move(map_builder).Build()); + EXPECT_FALSE(map->empty()); + EXPECT_EQ(map->size(), 3); + ASSERT_OK_AND_ASSIGN(auto entry, + map->Get(MapValue::GetContext(value_factory), key)); + EXPECT_TRUE((*entry)->Is()); + EXPECT_TRUE((*entry).As()->empty()); + ASSERT_OK_AND_ASSIGN( + entry, map->Get(MapValue::GetContext(value_factory), make_key(1))); + EXPECT_TRUE((*entry)->Is()); + EXPECT_TRUE((*entry).As()->empty()); + ASSERT_OK_AND_ASSIGN( + entry, map->Get(MapValue::GetContext(value_factory), make_key(2))); + EXPECT_TRUE((*entry)->Is()); + EXPECT_EQ((*entry).As()->ToString(), "foo"); + EXPECT_EQ(map->DebugString(), "{0: \"\", 1: \"\", 2: \"foo\"}"); + ASSERT_OK_AND_ASSIGN(auto keys, + map->ListKeys(MapValue::ListKeysContext(value_factory))); + EXPECT_FALSE(keys->empty()); + EXPECT_EQ(keys->size(), 3); + EXPECT_EQ(keys->DebugString(), "[0, 1, 2]"); +} + +TEST(MapValueBuilder, SpecializedGeneric) { + TypeFactory type_factory(MemoryManager::Global()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); + auto map_builder = MapValueBuilder( + value_factory, type_factory.GetIntType(), type_factory.GetStringType()); + auto make_key = [&](int64_t value = 0) -> Handle { + return value_factory.CreateIntValue(value); + }; + auto key = make_key(); + auto make_value = [&](absl::string_view value = + absl::string_view()) -> Handle { + return value_factory.CreateStringValue(value).value(); + }; + auto value = make_value(); + EXPECT_THAT(map_builder.Update(key, value), + IsOkAndHolds(IsFalse())); // lvalue, lvalue + EXPECT_THAT(map_builder.Update(key, make_value()), + IsOkAndHolds(IsFalse())); // lvalue, rvalue + EXPECT_THAT(map_builder.Insert(key, value), + IsOkAndHolds(IsTrue())); // lvalue, lvalue + EXPECT_THAT(map_builder.Insert(key, make_value()), + IsOkAndHolds(IsFalse())); // lvalue, rvalue + EXPECT_THAT(map_builder.Insert(make_key(1), value), + IsOkAndHolds(IsTrue())); // rvalue, lvalue + EXPECT_THAT(map_builder.Insert(make_key(1), make_value()), + IsOkAndHolds(IsFalse())); // rvalue, rvalue + EXPECT_THAT(map_builder.Update(key, value), + IsOkAndHolds(IsTrue())); // lvalue, lvalue + EXPECT_THAT(map_builder.Update(key, make_value()), + IsOkAndHolds(IsTrue())); // lvalue, rvalue + EXPECT_THAT(map_builder.InsertOrAssign(key, value), + IsOkAndHolds(IsFalse())); // lvalue, lvalue + EXPECT_THAT(map_builder.InsertOrAssign(key, make_value()), + IsOkAndHolds(IsFalse())); // lvalue, rvalue + EXPECT_THAT(map_builder.InsertOrAssign(make_key(2), value), + IsOkAndHolds(IsTrue())); // rvalue, lvalue + EXPECT_THAT(map_builder.InsertOrAssign(make_key(2), make_value("foo")), + IsOkAndHolds(IsFalse())); // rvalue, rvalue + EXPECT_TRUE(map_builder.Has(key)); + EXPECT_FALSE(map_builder.empty()); + EXPECT_EQ(map_builder.size(), 3); + EXPECT_EQ(map_builder.DebugString(), "{0: \"\", 1: \"\", 2: \"foo\"}"); + ASSERT_OK_AND_ASSIGN(auto map, std::move(map_builder).Build()); + EXPECT_FALSE(map->empty()); + EXPECT_EQ(map->size(), 3); + ASSERT_OK_AND_ASSIGN(auto entry, + map->Get(MapValue::GetContext(value_factory), key)); + EXPECT_TRUE((*entry)->Is()); + EXPECT_TRUE((*entry).As()->empty()); + ASSERT_OK_AND_ASSIGN( + entry, map->Get(MapValue::GetContext(value_factory), make_key(1))); + EXPECT_TRUE((*entry)->Is()); + EXPECT_TRUE((*entry).As()->empty()); + ASSERT_OK_AND_ASSIGN( + entry, map->Get(MapValue::GetContext(value_factory), make_key(2))); + EXPECT_TRUE((*entry)->Is()); + EXPECT_EQ((*entry).As()->ToString(), "foo"); + EXPECT_EQ(map->DebugString(), "{0: \"\", 1: \"\", 2: \"foo\"}"); } template void TestMapBuilder(GetKey get_key, GetValue get_value, MakeKey make_key1, MakeKey make_key2, MakeKey make_key3, MakeValue make_value1, - MakeValue make_value2, absl::string_view debug_string) { + MakeValue make_value2, absl::string_view debug_string, + absl::string_view keys_debug_string) { TypeFactory type_factory(MemoryManager::Global()); TypeManager type_manager(type_factory, TypeProvider::Builtin()); ValueFactory value_factory(type_manager); @@ -312,13 +592,13 @@ void TestMapBuilder(GetKey get_key, GetValue get_value, MakeKey make_key1, IsOkAndHolds(IsTrue())); // lvalue, lvalue EXPECT_THAT(map_builder.Update(key, make_value1(value_factory)), IsOkAndHolds(IsTrue())); // lvalue, rvalue - EXPECT_THAT(map_builder.InsertOrUpdate(key, value), + EXPECT_THAT(map_builder.InsertOrAssign(key, value), IsOkAndHolds(IsFalse())); // lvalue, lvalue - EXPECT_THAT(map_builder.InsertOrUpdate(key, make_value1(value_factory)), + EXPECT_THAT(map_builder.InsertOrAssign(key, make_value1(value_factory)), IsOkAndHolds(IsFalse())); // lvalue, rvalue - EXPECT_THAT(map_builder.InsertOrUpdate(make_key3(value_factory), value), + EXPECT_THAT(map_builder.InsertOrAssign(make_key3(value_factory), value), IsOkAndHolds(IsTrue())); // rvalue, lvalue - EXPECT_THAT(map_builder.InsertOrUpdate(make_key3(value_factory), + EXPECT_THAT(map_builder.InsertOrAssign(make_key3(value_factory), make_value2(value_factory)), IsOkAndHolds(IsFalse())); // rvalue, rvalue EXPECT_TRUE(map_builder.Has(key)); @@ -344,6 +624,22 @@ void TestMapBuilder(GetKey get_key, GetValue get_value, MakeKey make_key1, EXPECT_EQ(*((*entry).template As()), *((make_value2(value_factory)).template As())); EXPECT_EQ(map->DebugString(), debug_string); + ASSERT_OK_AND_ASSIGN(auto keys, + map->ListKeys(MapValue::ListKeysContext(value_factory))); + EXPECT_FALSE(keys->empty()); + EXPECT_EQ(keys->size(), 3); + ASSERT_OK_AND_ASSIGN(auto element, + keys->Get(ListValue::GetContext(value_factory), 0)); + EXPECT_EQ((*element).template As(), (*key).template As()); + ASSERT_OK_AND_ASSIGN(element, + keys->Get(ListValue::GetContext(value_factory), 1)); + EXPECT_EQ((*element).template As(), + (*make_key2(value_factory)).template As()); + ASSERT_OK_AND_ASSIGN(element, + keys->Get(ListValue::GetContext(value_factory), 2)); + EXPECT_EQ((*element).template As(), + (*make_key3(value_factory)).template As()); + EXPECT_EQ(keys->DebugString(), keys_debug_string); } template @@ -383,35 +679,35 @@ TEST(MapValueBuilder, IntBool) { TestMapBuilder( &TypeFactory::GetIntType, &TypeFactory::GetBoolType, MakeIntValue<0>, MakeIntValue<1>, MakeIntValue<2>, MakeBoolValue, - MakeBoolValue, "{0: false, 1: false, 2: true}"); + MakeBoolValue, "{0: false, 1: false, 2: true}", "[0, 1, 2]"); } TEST(MapValueBuilder, IntInt) { TestMapBuilder( &TypeFactory::GetIntType, &TypeFactory::GetIntType, MakeIntValue<0>, MakeIntValue<1>, MakeIntValue<2>, MakeIntValue<0>, MakeIntValue<1>, - "{0: 0, 1: 0, 2: 1}"); + "{0: 0, 1: 0, 2: 1}", "[0, 1, 2]"); } TEST(MapValueBuilder, IntUint) { TestMapBuilder( &TypeFactory::GetIntType, &TypeFactory::GetUintType, MakeIntValue<0>, MakeIntValue<1>, MakeIntValue<2>, MakeUintValue<0>, MakeUintValue<1>, - "{0: 0u, 1: 0u, 2: 1u}"); + "{0: 0u, 1: 0u, 2: 1u}", "[0, 1, 2]"); } TEST(MapValueBuilder, IntDouble) { TestMapBuilder( &TypeFactory::GetIntType, &TypeFactory::GetDoubleType, MakeIntValue<0>, MakeIntValue<1>, MakeIntValue<2>, MakeDoubleValue(0.0), - MakeDoubleValue(1.0), "{0: 0.0, 1: 0.0, 2: 1.0}"); + MakeDoubleValue(1.0), "{0: 0.0, 1: 0.0, 2: 1.0}", "[0, 1, 2]"); } TEST(MapValueBuilder, IntDuration) { TestMapBuilder( &TypeFactory::GetIntType, &TypeFactory::GetDurationType, MakeIntValue<0>, MakeIntValue<1>, MakeIntValue<2>, MakeDurationValue(absl::ZeroDuration()), - MakeDurationValue(absl::Seconds(1)), "{0: 0, 1: 0, 2: 1s}"); + MakeDurationValue(absl::Seconds(1)), "{0: 0, 1: 0, 2: 1s}", "[0, 1, 2]"); } TEST(MapValueBuilder, IntTimestamp) { @@ -420,35 +716,36 @@ TEST(MapValueBuilder, IntTimestamp) { MakeIntValue<1>, MakeIntValue<2>, MakeTimestampValue(absl::UnixEpoch()), MakeTimestampValue(absl::UnixEpoch() + absl::Seconds(1)), "{0: 1970-01-01T00:00:00Z, 1: 1970-01-01T00:00:00Z, 2: " - "1970-01-01T00:00:01Z}"); + "1970-01-01T00:00:01Z}", + "[0, 1, 2]"); } TEST(MapValueBuilder, UintBool) { TestMapBuilder( &TypeFactory::GetUintType, &TypeFactory::GetBoolType, MakeUintValue<0>, MakeUintValue<1>, MakeUintValue<2>, MakeBoolValue, - MakeBoolValue, "{0u: false, 1u: false, 2u: true}"); + MakeBoolValue, "{0u: false, 1u: false, 2u: true}", "[0u, 1u, 2u]"); } TEST(MapValueBuilder, UintInt) { TestMapBuilder( &TypeFactory::GetUintType, &TypeFactory::GetIntType, MakeUintValue<0>, MakeUintValue<1>, MakeUintValue<2>, MakeIntValue<0>, MakeIntValue<1>, - "{0u: 0, 1u: 0, 2u: 1}"); + "{0u: 0, 1u: 0, 2u: 1}", "[0u, 1u, 2u]"); } TEST(MapValueBuilder, UintUint) { TestMapBuilder( &TypeFactory::GetUintType, &TypeFactory::GetUintType, MakeUintValue<0>, MakeUintValue<1>, MakeUintValue<2>, MakeUintValue<0>, MakeUintValue<1>, - "{0u: 0u, 1u: 0u, 2u: 1u}"); + "{0u: 0u, 1u: 0u, 2u: 1u}", "[0u, 1u, 2u]"); } TEST(MapValueBuilder, UintDouble) { TestMapBuilder( &TypeFactory::GetUintType, &TypeFactory::GetDoubleType, MakeUintValue<0>, MakeUintValue<1>, MakeUintValue<2>, MakeDoubleValue(0.0), - MakeDoubleValue(1.0), "{0u: 0.0, 1u: 0.0, 2u: 1.0}"); + MakeDoubleValue(1.0), "{0u: 0.0, 1u: 0.0, 2u: 1.0}", "[0u, 1u, 2u]"); } TEST(MapValueBuilder, UintDuration) { @@ -456,7 +753,8 @@ TEST(MapValueBuilder, UintDuration) { &TypeFactory::GetUintType, &TypeFactory::GetDurationType, MakeUintValue<0>, MakeUintValue<1>, MakeUintValue<2>, MakeDurationValue(absl::ZeroDuration()), - MakeDurationValue(absl::Seconds(1)), "{0u: 0, 1u: 0, 2u: 1s}"); + MakeDurationValue(absl::Seconds(1)), "{0u: 0, 1u: 0, 2u: 1s}", + "[0u, 1u, 2u]"); } TEST(MapValueBuilder, UintTimestamp) { @@ -466,10 +764,9 @@ TEST(MapValueBuilder, UintTimestamp) { MakeTimestampValue(absl::UnixEpoch()), MakeTimestampValue(absl::UnixEpoch() + absl::Seconds(1)), "{0u: 1970-01-01T00:00:00Z, 1u: 1970-01-01T00:00:00Z, 2u: " - "1970-01-01T00:00:01Z}"); + "1970-01-01T00:00:01Z}", + "[0u, 1u, 2u]"); } -// TODO(issues/5): add Generic, Generic, and friends - } // namespace } // namespace cel From 4c3566587ff15d39237889ddbabb635c16f93766 Mon Sep 17 00:00:00 2001 From: jcking Date: Thu, 18 May 2023 19:32:40 +0000 Subject: [PATCH 291/303] Add missing aliases to the wrapper types PiperOrigin-RevId: 533215268 --- base/type.cc | 2 ++ base/types/wrapper_type.cc | 34 ++++++++++++++++++++++++++++++++++ base/types/wrapper_type.h | 17 +++++++++++++++++ 3 files changed, 53 insertions(+) diff --git a/base/type.cc b/base/type.cc index dcdda19b1..281318422 100644 --- a/base/type.cc +++ b/base/type.cc @@ -101,6 +101,8 @@ absl::Span Type::aliases() const { return static_cast(this)->aliases(); case Kind::kMap: return static_cast(this)->aliases(); + case Kind::kWrapper: + return static_cast(this)->aliases(); default: // Everything else does not support aliases. return absl::Span(); diff --git a/base/types/wrapper_type.cc b/base/types/wrapper_type.cc index c7ec24106..2de285528 100644 --- a/base/types/wrapper_type.cc +++ b/base/types/wrapper_type.cc @@ -15,6 +15,8 @@ #include "base/types/wrapper_type.h" #include "absl/base/optimization.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" namespace cel { @@ -46,6 +48,20 @@ absl::string_view WrapperType::name() const { } } +absl::Span WrapperType::aliases() const { + switch (base_internal::Metadata::GetInlineVariant(*this)) { + case Kind::kDouble: + return static_cast(this)->aliases(); + case Kind::kInt: + return static_cast(this)->aliases(); + case Kind::kUint: + return static_cast(this)->aliases(); + default: + // The other wrappers do not have aliases. + return absl::Span(); + } +} + const Handle& WrapperType::wrapped() const { switch (base_internal::Metadata::GetInlineVariant(*this)) { case Kind::kBool: @@ -66,4 +82,22 @@ const Handle& WrapperType::wrapped() const { } } +absl::Span DoubleWrapperType::aliases() const { + static constexpr absl::string_view kAliases[] = { + "google.protobuf.FloatValue"}; + return absl::MakeConstSpan(kAliases); +} + +absl::Span IntWrapperType::aliases() const { + static constexpr absl::string_view kAliases[] = { + "google.protobuf.Int32Value"}; + return absl::MakeConstSpan(kAliases); +} + +absl::Span UintWrapperType::aliases() const { + static constexpr absl::string_view kAliases[] = { + "google.protobuf.UInt32Value"}; + return absl::MakeConstSpan(kAliases); +} + } // namespace cel diff --git a/base/types/wrapper_type.h b/base/types/wrapper_type.h index 9de378010..1b74379e3 100644 --- a/base/types/wrapper_type.h +++ b/base/types/wrapper_type.h @@ -22,6 +22,7 @@ #include "absl/base/attributes.h" #include "absl/log/absl_check.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "base/internal/data.h" #include "base/kind.h" #include "base/type.h" @@ -69,6 +70,7 @@ class WrapperType : public Type, base_internal::InlineData { const Handle& wrapped() const; private: + friend class Type; friend class BoolWrapperType; friend class BytesWrapperType; friend class DoubleWrapperType; @@ -76,6 +78,9 @@ class WrapperType : public Type, base_internal::InlineData { friend class StringWrapperType; friend class UintWrapperType; + // See Type::aliases(). + absl::Span aliases() const; + using Base::Base; }; @@ -189,6 +194,7 @@ class DoubleWrapperType final : public WrapperType { const Handle& wrapped() const { return DoubleType::Get(); } private: + friend class WrapperType; friend class TypeFactory; template friend struct base_internal::AnyData; @@ -202,6 +208,9 @@ class DoubleWrapperType final : public WrapperType { << base_internal::kInlineVariantShift); constexpr DoubleWrapperType() : WrapperType(kMetadata) {} + + // See Type::aliases(). + absl::Span aliases() const; }; class IntWrapperType final : public WrapperType { @@ -226,6 +235,7 @@ class IntWrapperType final : public WrapperType { const Handle& wrapped() const { return IntType::Get(); } private: + friend class WrapperType; friend class TypeFactory; template friend struct base_internal::AnyData; @@ -239,6 +249,9 @@ class IntWrapperType final : public WrapperType { << base_internal::kInlineVariantShift); constexpr IntWrapperType() : WrapperType(kMetadata) {} + + // See Type::aliases(). + absl::Span aliases() const; }; class StringWrapperType final : public WrapperType { @@ -300,6 +313,7 @@ class UintWrapperType final : public WrapperType { const Handle& wrapped() const { return UintType::Get(); } private: + friend class WrapperType; friend class TypeFactory; template friend struct base_internal::AnyData; @@ -313,6 +327,9 @@ class UintWrapperType final : public WrapperType { << base_internal::kInlineVariantShift); constexpr UintWrapperType() : WrapperType(kMetadata) {} + + // See Type::aliases(). + absl::Span aliases() const; }; extern template class Handle; From c1a562afc77468e41527fb6ac04630fadd4807d2 Mon Sep 17 00:00:00 2001 From: kuat Date: Thu, 18 May 2023 19:39:53 +0000 Subject: [PATCH 292/303] Fix the remaining usages of LOG macro. Add "unused" attribute to public headers. PiperOrigin-RevId: 533217495 --- eval/compiler/constant_folding.cc | 2 +- eval/eval/evaluator_core.cc | 6 +++--- eval/eval/evaluator_stack.h | 16 ++++++++-------- eval/eval/select_step.cc | 2 +- eval/public/cel_expr_builder_factory.cc | 8 ++++---- eval/public/portable_cel_expr_builder_factory.cc | 4 ++-- eval/public/structs/field_access_impl.cc | 2 +- eval/public/structs/legacy_type_provider.h | 5 +++-- 8 files changed, 23 insertions(+), 22 deletions(-) diff --git a/eval/compiler/constant_folding.cc b/eval/compiler/constant_folding.cc index 6ddf660b7..db7b6f7e0 100644 --- a/eval/compiler/constant_folding.cc +++ b/eval/compiler/constant_folding.cc @@ -359,7 +359,7 @@ class ConstantFoldingTransform { } bool operator()(absl::monostate) { - LOG(ERROR) << "Unsupported Expr kind"; + ABSL_LOG(ERROR) << "Unsupported Expr kind"; return false; } diff --git a/eval/eval/evaluator_core.cc b/eval/eval/evaluator_core.cc index a35986c3c..9e0683b5e 100644 --- a/eval/eval/evaluator_core.cc +++ b/eval/eval/evaluator_core.cc @@ -51,7 +51,7 @@ const ExpressionStep* ExecutionFrame::Next() { if (pc_ < end_pos) return execution_path_[pc_++].get(); if (pc_ > end_pos) { - LOG(ERROR) << "Attempting to step beyond the end of execution path."; + ABSL_LOG(ERROR) << "Attempting to step beyond the end of execution path."; } return nullptr; } @@ -167,8 +167,8 @@ absl::StatusOr> ExecutionFrame::Evaluate( } if (value_stack().empty()) { - LOG(ERROR) << "Stack is empty after a ExpressionStep.Evaluate. " - "Try to disable short-circuiting."; + ABSL_LOG(ERROR) << "Stack is empty after a ExpressionStep.Evaluate. " + "Try to disable short-circuiting."; continue; } CEL_RETURN_IF_ERROR( diff --git a/eval/eval/evaluator_stack.h b/eval/eval/evaluator_stack.h index 63ace19fb..b7f8f5420 100644 --- a/eval/eval/evaluator_stack.h +++ b/eval/eval/evaluator_stack.h @@ -46,8 +46,8 @@ class EvaluatorStack { // Please note that calls to Push may invalidate returned Span object. absl::Span> GetSpan(size_t size) const { if (!HasEnough(size)) { - LOG(ERROR) << "Requested span size (" << size - << ") exceeds current stack size: " << current_size_; + ABSL_LOG(ERROR) << "Requested span size (" << size + << ") exceeds current stack size: " << current_size_; } return absl::Span>( stack_.data() + current_size_ - size, size); @@ -65,7 +65,7 @@ class EvaluatorStack { // Checking that stack is not empty is caller's responsibility. const cel::Handle& Peek() const { if (empty()) { - LOG(ERROR) << "Peeking on empty EvaluatorStack"; + ABSL_LOG(ERROR) << "Peeking on empty EvaluatorStack"; } return stack_[current_size_ - 1]; } @@ -74,7 +74,7 @@ class EvaluatorStack { // Checking that stack is not empty is caller's responsibility. const AttributeTrail& PeekAttribute() const { if (empty()) { - LOG(ERROR) << "Peeking on empty EvaluatorStack"; + ABSL_LOG(ERROR) << "Peeking on empty EvaluatorStack"; } return attribute_stack_[current_size_ - 1]; } @@ -83,8 +83,8 @@ class EvaluatorStack { // Checking that stack has enough elements is caller's responsibility. void Pop(size_t size) { if (!HasEnough(size)) { - LOG(ERROR) << "Trying to pop more elements (" << size - << ") than the current stack size: " << current_size_; + ABSL_LOG(ERROR) << "Trying to pop more elements (" << size + << ") than the current stack size: " << current_size_; } while (size > 0) { stack_.pop_back(); @@ -101,7 +101,7 @@ class EvaluatorStack { void Push(cel::Handle value, AttributeTrail attribute) { if (current_size_ >= max_size()) { - LOG(ERROR) << "No room to push more elements on to EvaluatorStack"; + ABSL_LOG(ERROR) << "No room to push more elements on to EvaluatorStack"; } stack_.push_back(std::move(value)); attribute_stack_.push_back(std::move(attribute)); @@ -118,7 +118,7 @@ class EvaluatorStack { // Checking that stack is not empty is caller's responsibility. void PopAndPush(cel::Handle value, AttributeTrail attribute) { if (empty()) { - LOG(ERROR) << "Cannot PopAndPush on empty stack."; + ABSL_LOG(ERROR) << "Cannot PopAndPush on empty stack."; } stack_[current_size_ - 1] = std::move(value); attribute_stack_[current_size_ - 1] = std::move(attribute); diff --git a/eval/eval/select_step.cc b/eval/eval/select_step.cc index 08635e8a4..07c0b93e5 100644 --- a/eval/eval/select_step.cc +++ b/eval/eval/select_step.cc @@ -116,7 +116,7 @@ absl::optional> CheckForMarkedAttributes( } // Invariant broken (an invalid CEL Attribute shouldn't match anything). // Log and return a CelError. - LOG(ERROR) + ABSL_LOG(ERROR) << "Invalid attribute pattern matched select path: " << attribute_string.status().ToString(); // NOLINT: OSS compatibility return CreateErrorValueFromView(Arena::Create( diff --git a/eval/public/cel_expr_builder_factory.cc b/eval/public/cel_expr_builder_factory.cc index 679d60e38..b0eda9a55 100644 --- a/eval/public/cel_expr_builder_factory.cc +++ b/eval/public/cel_expr_builder_factory.cc @@ -39,13 +39,13 @@ std::unique_ptr CreateCelExpressionBuilder( google::protobuf::MessageFactory* message_factory, const InterpreterOptions& options) { if (descriptor_pool == nullptr) { - LOG(ERROR) << "Cannot pass nullptr as descriptor pool to " - "CreateCelExpressionBuilder"; + ABSL_LOG(ERROR) << "Cannot pass nullptr as descriptor pool to " + "CreateCelExpressionBuilder"; return nullptr; } if (auto s = ValidateStandardMessageTypes(*descriptor_pool); !s.ok()) { - LOG(WARNING) << "Failed to validate standard message types: " - << s.ToString(); // NOLINT: OSS compatibility + ABSL_LOG(WARNING) << "Failed to validate standard message types: " + << s.ToString(); // NOLINT: OSS compatibility return nullptr; } diff --git a/eval/public/portable_cel_expr_builder_factory.cc b/eval/public/portable_cel_expr_builder_factory.cc index f22298ea1..d920a2125 100644 --- a/eval/public/portable_cel_expr_builder_factory.cc +++ b/eval/public/portable_cel_expr_builder_factory.cc @@ -34,8 +34,8 @@ std::unique_ptr CreatePortableExprBuilder( std::unique_ptr type_provider, const InterpreterOptions& options) { if (type_provider == nullptr) { - LOG(ERROR) << "Cannot pass nullptr as type_provider to " - "CreatePortableExprBuilder"; + ABSL_LOG(ERROR) << "Cannot pass nullptr as type_provider to " + "CreatePortableExprBuilder"; return nullptr; } cel::RuntimeOptions runtime_options = ConvertToRuntimeOptions(options); diff --git a/eval/public/structs/field_access_impl.cc b/eval/public/structs/field_access_impl.cc index 16233c545..7cc64fadb 100644 --- a/eval/public/structs/field_access_impl.cc +++ b/eval/public/structs/field_access_impl.cc @@ -599,7 +599,7 @@ class ScalarFieldSetter : public FieldSetter { bool SetMessage(const Message* value) const { if (!value) { - LOG(ERROR) << "Message is NULL"; + ABSL_LOG(ERROR) << "Message is NULL"; return true; } if (value->GetDescriptor()->full_name() == diff --git a/eval/public/structs/legacy_type_provider.h b/eval/public/structs/legacy_type_provider.h index 3bb22e443..a563c73a0 100644 --- a/eval/public/structs/legacy_type_provider.h +++ b/eval/public/structs/legacy_type_provider.h @@ -46,7 +46,7 @@ class LegacyTypeProvider : public cel::TypeProvider { // created ones, the TypeInfoApis returned from this method should be the same // as the ones used in value creation. virtual absl::optional ProvideLegacyTypeInfo( - absl::string_view name) const { + ABSL_ATTRIBUTE_UNUSED absl::string_view name) const { return absl::nullopt; } @@ -61,7 +61,8 @@ class LegacyTypeProvider : public cel::TypeProvider { // TODO(issues/5): Move protobuf-Any API from top level // [Legacy]TypeProviders. virtual absl::optional - ProvideLegacyAnyPackingApis(absl::string_view name) const { + ProvideLegacyAnyPackingApis( + ABSL_ATTRIBUTE_UNUSED absl::string_view name) const { return absl::nullopt; } }; From c2d4b175d5c315e0b69b4d9d73f52a0a6ce26638 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Fri, 19 May 2023 03:17:40 +0000 Subject: [PATCH 293/303] Internal change PiperOrigin-RevId: 533327502 --- .bazelversion | 2 +- .gitignore | 5 +- eval/internal/interop.cc | 13 +- eval/internal/interop.h | 5 +- eval/internal/interop_test.cc | 7 +- eval/public/BUILD | 7 +- eval/public/builtin_func_registrar.cc | 231 ++++++++++++-------------- 7 files changed, 125 insertions(+), 145 deletions(-) diff --git a/.bazelversion b/.bazelversion index dfda3e0b4..6abaeb2f9 100644 --- a/.bazelversion +++ b/.bazelversion @@ -1 +1 @@ -6.1.0 +6.2.0 diff --git a/.gitignore b/.gitignore index 6d3e1b8bb..2eb327820 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ -# bazel produces these as symlinks, not directories bazel-bin -bazel-cel-cpp +bazel-eval bazel-genfiles bazel-out bazel-testlogs +bazel-cel-cpp +*~ diff --git a/eval/internal/interop.cc b/eval/internal/interop.cc index c3fb8d082..4bc2a55fa 100644 --- a/eval/internal/interop.cc +++ b/eval/internal/interop.cc @@ -437,16 +437,19 @@ absl::StatusOr ToLegacyValue(google::protobuf::Arena* arena, break; case Kind::kAny: break; - case Kind::kType: + case Kind::kType: { // Should be fine, so long as we are using an arena allocator. // We can only transport legacy type values. if (base_internal::Metadata::GetInlineVariant< - base_internal::InlinedTypeValueVariant>(*value) != + base_internal::InlinedTypeValueVariant>(*value) == base_internal::InlinedTypeValueVariant::kLegacy) { - return absl::UnimplementedError( - "only legacy type values can be used for interop"); + return CelValue::CreateCelTypeView(value.As()->name()); } - return CelValue::CreateCelTypeView(value.As()->name()); + auto* type_name = google::protobuf::Arena::Create( + arena, value.As()->name()); + + return CelValue::CreateCelTypeView(*type_name); + } case Kind::kBool: return CelValue::CreateBool(value.As()->value()); case Kind::kInt: diff --git a/eval/internal/interop.h b/eval/internal/interop.h index eb9d28396..3791b1895 100644 --- a/eval/internal/interop.h +++ b/eval/internal/interop.h @@ -27,6 +27,7 @@ #include "absl/types/variant.h" #include "base/value.h" #include "base/value_factory.h" +#include "base/values/type_value.h" #include "eval/public/cel_value.h" #include "eval/public/message_wrapper.h" @@ -131,7 +132,7 @@ Handle CreateUnknownValueFromView( const base_internal::UnknownSet* value); // Convert a legacy value to a modern value, CHECK failing if its not possible. -// This should only be used during rewritting of the evaluator when it is +// This should only be used during rewriting of the evaluator when it is // guaranteed that all modern and legacy values are interoperable, and the // memory manager is google::protobuf::Arena. Handle LegacyValueToModernValueOrDie( @@ -150,7 +151,7 @@ std::vector> LegacyValueToModernValueOrDie( bool unchecked = false); // Convert a modern value to a legacy value, CHECK failing if its not possible. -// This should only be used during rewritting of the evaluator when it is +// This should only be used during rewriting of the evaluator when it is // guaranteed that all modern and legacy values are interoperable, and the // memory manager is google::protobuf::Arena. google::api::expr::runtime::CelValue ModernValueToLegacyValueOrDie( diff --git a/eval/internal/interop_test.cc b/eval/internal/interop_test.cc index f7747c8c5..50c29bda5 100644 --- a/eval/internal/interop_test.cc +++ b/eval/internal/interop_test.cc @@ -279,15 +279,16 @@ TEST(ValueInterop, TypeToLegacy) { EXPECT_EQ(legacy_value.CelTypeOrDie().value(), "struct.that.does.not.Exist"); } -TEST(ValueInterop, ModernTypeUnimplemented) { +TEST(ValueInterop, ModernTypeToStringView) { google::protobuf::Arena arena; extensions::ProtoMemoryManager memory_manager(&arena); TypeFactory type_factory(memory_manager); TypeManager type_manager(type_factory, TypeProvider::Builtin()); ValueFactory value_factory(type_manager); auto value = value_factory.CreateTypeValue(type_factory.GetBoolType()); - EXPECT_THAT(ToLegacyValue(&arena, value), - StatusIs(absl::StatusCode::kUnimplemented)); + ASSERT_OK_AND_ASSIGN(CelValue legacy_value, ToLegacyValue(&arena, value)); + ASSERT_TRUE(legacy_value.IsCelType()); + EXPECT_EQ(legacy_value.CelTypeOrDie().value(), "bool"); } TEST(ValueInterop, StringFromLegacy) { diff --git a/eval/public/BUILD b/eval/public/BUILD index 65104bc69..fa82d9494 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -282,7 +282,6 @@ cc_library( "builtin_func_registrar.h", ], deps = [ - ":cel_builtins", ":cel_function", ":cel_function_registry", ":cel_number", @@ -293,15 +292,11 @@ cc_library( ":equality_function_registrar", ":logical_function_registrar", ":portable_cel_function_adapter", + "//base:builtins", "//base:function_adapter", "//base:handle", - "//base:type", "//base:value", - "//eval/eval:mutable_list_impl", "//eval/internal:interop", - "//eval/public/containers:container_backed_list_impl", - "//extensions/protobuf:memory_manager", - "//internal:casts", "//internal:overflow", "//internal:proto_time_encoding", "//internal:status_macros", diff --git a/eval/public/builtin_func_registrar.cc b/eval/public/builtin_func_registrar.cc index 451312e80..1c0955e0e 100644 --- a/eval/public/builtin_func_registrar.cc +++ b/eval/public/builtin_func_registrar.cc @@ -30,30 +30,25 @@ #include "absl/time/civil_time.h" #include "absl/time/time.h" #include "absl/types/optional.h" +#include "base/builtins.h" #include "base/function_adapter.h" #include "base/handle.h" -#include "base/type_factory.h" #include "base/value.h" #include "base/value_factory.h" #include "base/values/bytes_value.h" #include "base/values/list_value.h" #include "base/values/map_value.h" #include "base/values/string_value.h" -#include "eval/eval/mutable_list_impl.h" #include "eval/internal/interop.h" -#include "eval/public/cel_builtins.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_number.h" #include "eval/public/cel_options.h" #include "eval/public/cel_value.h" #include "eval/public/comparison_functions.h" #include "eval/public/container_function_registrar.h" -#include "eval/public/containers/container_backed_list_impl.h" #include "eval/public/equality_function_registrar.h" #include "eval/public/logical_function_registrar.h" #include "eval/public/portable_cel_function_adapter.h" -#include "extensions/protobuf/memory_manager.h" -#include "internal/casts.h" #include "internal/overflow.h" #include "internal/proto_time_encoding.h" #include "internal/status_macros.h" @@ -68,8 +63,6 @@ namespace { using ::cel::BinaryFunctionAdapter; using ::cel::BytesValue; using ::cel::Handle; -using ::cel::ListValue; -using ::cel::MapValue; using ::cel::StringValue; using ::cel::UnaryFunctionAdapter; using ::cel::Value; @@ -235,19 +228,19 @@ template absl::Status RegisterArithmeticFunctionsForType(CelFunctionRegistry* registry) { using FunctionAdapter = cel::BinaryFunctionAdapter, Type, Type>; CEL_RETURN_IF_ERROR(registry->Register( - FunctionAdapter::CreateDescriptor(builtin::kAdd, false), + FunctionAdapter::CreateDescriptor(cel::builtin::kAdd, false), FunctionAdapter::WrapFunction(&Add))); CEL_RETURN_IF_ERROR(registry->Register( - FunctionAdapter::CreateDescriptor(builtin::kSubtract, false), + FunctionAdapter::CreateDescriptor(cel::builtin::kSubtract, false), FunctionAdapter::WrapFunction(&Sub))); CEL_RETURN_IF_ERROR(registry->Register( - FunctionAdapter::CreateDescriptor(builtin::kMultiply, false), + FunctionAdapter::CreateDescriptor(cel::builtin::kMultiply, false), FunctionAdapter::WrapFunction(&Mul))); return registry->Register( - FunctionAdapter::CreateDescriptor(builtin::kDivide, false), + FunctionAdapter::CreateDescriptor(cel::builtin::kDivide, false), FunctionAdapter::WrapFunction(&Div)); } @@ -261,13 +254,13 @@ absl::Status RegisterNumericArithmeticFunctions( // Modulo CEL_RETURN_IF_ERROR(registry->Register( BinaryFunctionAdapter, int64_t, int64_t>::CreateDescriptor( - builtin::kModulo, false), + cel::builtin::kModulo, false), BinaryFunctionAdapter, int64_t, int64_t>::WrapFunction( &Modulo))); CEL_RETURN_IF_ERROR(registry->Register( BinaryFunctionAdapter, uint64_t, - uint64_t>::CreateDescriptor(builtin::kModulo, + uint64_t>::CreateDescriptor(cel::builtin::kModulo, false), BinaryFunctionAdapter, uint64_t, uint64_t>::WrapFunction( &Modulo))); @@ -275,7 +268,7 @@ absl::Status RegisterNumericArithmeticFunctions( // Negation group CEL_RETURN_IF_ERROR(registry->Register( UnaryFunctionAdapter, int64_t>::CreateDescriptor( - builtin::kNeg, false), + cel::builtin::kNeg, false), UnaryFunctionAdapter, int64_t>::WrapFunction( [](ValueFactory& value_factory, int64_t value) -> Handle { auto inv = cel::internal::CheckedNegation(value); @@ -286,7 +279,7 @@ absl::Status RegisterNumericArithmeticFunctions( }))); return registry->Register( - UnaryFunctionAdapter::CreateDescriptor(builtin::kNeg, + UnaryFunctionAdapter::CreateDescriptor(cel::builtin::kNeg, false), UnaryFunctionAdapter::WrapFunction( [](ValueFactory&, double value) -> double { return -value; })); @@ -547,9 +540,9 @@ bool StringStartsWith(ValueFactory&, const StringValue& value, absl::Status RegisterSetMembershipFunctions(CelFunctionRegistry* registry, const InterpreterOptions& options) { constexpr std::array in_operators = { - builtin::kIn, // @in for map and list types. - builtin::kInFunction, // deprecated in() -- for backwards compat - builtin::kInDeprecated, // deprecated _in_ -- for backwards compat + cel::builtin::kIn, // @in for map and list types. + cel::builtin::kInFunction, // deprecated in() -- for backwards compat + cel::builtin::kInDeprecated, // deprecated _in_ -- for backwards compat }; if (options.enable_list_contains) { @@ -742,7 +735,7 @@ absl::Status RegisterRegexFunctions(CelFunctionRegistry* registry, const StringValue&>; CEL_RETURN_IF_ERROR( registry->Register(MatchFnAdapter::CreateDescriptor( - builtin::kRegexMatch, receiver_style), + cel::builtin::kRegexMatch, receiver_style), MatchFnAdapter::WrapFunction(regex_matches))); } } // if options.enable_regex @@ -756,19 +749,19 @@ absl::Status RegisterStringFunctions(CelFunctionRegistry* registry, for (bool receiver_style : {true, false}) { CEL_RETURN_IF_ERROR(registry->Register( BinaryFunctionAdapter:: - CreateDescriptor(builtin::kStringContains, receiver_style), + CreateDescriptor(cel::builtin::kStringContains, receiver_style), BinaryFunctionAdapter:: WrapFunction(StringContains))); CEL_RETURN_IF_ERROR(registry->Register( BinaryFunctionAdapter:: - CreateDescriptor(builtin::kStringEndsWith, receiver_style), + CreateDescriptor(cel::builtin::kStringEndsWith, receiver_style), BinaryFunctionAdapter:: WrapFunction(StringEndsWith))); CEL_RETURN_IF_ERROR(registry->Register( BinaryFunctionAdapter:: - CreateDescriptor(builtin::kStringStartsWith, receiver_style), + CreateDescriptor(cel::builtin::kStringStartsWith, receiver_style), BinaryFunctionAdapter:: WrapFunction(StringStartsWith))); } @@ -779,14 +772,14 @@ absl::Status RegisterStringFunctions(CelFunctionRegistry* registry, BinaryFunctionAdapter>, const StringValue&, const StringValue&>; CEL_RETURN_IF_ERROR(registry->Register( - StrCatFnAdapter::CreateDescriptor(builtin::kAdd, false), + StrCatFnAdapter::CreateDescriptor(cel::builtin::kAdd, false), StrCatFnAdapter::WrapFunction(&ConcatString))); using BytesCatFnAdapter = BinaryFunctionAdapter>, const BytesValue&, const BytesValue&>; CEL_RETURN_IF_ERROR(registry->Register( - BytesCatFnAdapter::CreateDescriptor(builtin::kAdd, false), + BytesCatFnAdapter::CreateDescriptor(cel::builtin::kAdd, false), BytesCatFnAdapter::WrapFunction(&ConcatBytes))); } @@ -807,11 +800,11 @@ absl::Status RegisterStringFunctions(CelFunctionRegistry* registry, UnaryFunctionAdapter, const StringValue&>; CEL_RETURN_IF_ERROR( registry->Register(StrSizeFnAdapter::CreateDescriptor( - builtin::kSize, /*receiver_style=*/true), + cel::builtin::kSize, /*receiver_style=*/true), StrSizeFnAdapter::WrapFunction(size_func))); CEL_RETURN_IF_ERROR( registry->Register(StrSizeFnAdapter::CreateDescriptor( - builtin::kSize, /*receiver_style=*/false), + cel::builtin::kSize, /*receiver_style=*/false), StrSizeFnAdapter::WrapFunction(size_func))); // Bytes size @@ -823,11 +816,11 @@ absl::Status RegisterStringFunctions(CelFunctionRegistry* registry, using BytesSizeFnAdapter = UnaryFunctionAdapter; CEL_RETURN_IF_ERROR( registry->Register(BytesSizeFnAdapter::CreateDescriptor( - builtin::kSize, /*receiver_style=*/true), + cel::builtin::kSize, /*receiver_style=*/true), BytesSizeFnAdapter::WrapFunction(bytes_size_func))); CEL_RETURN_IF_ERROR( registry->Register(BytesSizeFnAdapter::CreateDescriptor( - builtin::kSize, /*receiver_style=*/false), + cel::builtin::kSize, /*receiver_style=*/false), BytesSizeFnAdapter::WrapFunction(bytes_size_func))); return absl::OkStatus(); @@ -837,7 +830,7 @@ absl::Status RegisterTimestampFunctions(CelFunctionRegistry* registry, const InterpreterOptions& options) { CEL_RETURN_IF_ERROR(registry->Register( BinaryFunctionAdapter, absl::Time, const StringValue&>:: - CreateDescriptor(builtin::kFullYear, true), + CreateDescriptor(cel::builtin::kFullYear, true), BinaryFunctionAdapter, absl::Time, const StringValue&>:: WrapFunction([](ValueFactory& value_factory, absl::Time ts, const StringValue& tz) -> Handle { @@ -846,7 +839,7 @@ absl::Status RegisterTimestampFunctions(CelFunctionRegistry* registry, CEL_RETURN_IF_ERROR(registry->Register( UnaryFunctionAdapter, absl::Time>::CreateDescriptor( - builtin::kFullYear, true), + cel::builtin::kFullYear, true), UnaryFunctionAdapter, absl::Time>::WrapFunction( [](ValueFactory& value_factory, absl::Time ts) -> Handle { return GetFullYear(value_factory, ts, ""); @@ -854,7 +847,7 @@ absl::Status RegisterTimestampFunctions(CelFunctionRegistry* registry, CEL_RETURN_IF_ERROR(registry->Register( BinaryFunctionAdapter, absl::Time, const StringValue&>:: - CreateDescriptor(builtin::kMonth, true), + CreateDescriptor(cel::builtin::kMonth, true), BinaryFunctionAdapter, absl::Time, const StringValue&>:: WrapFunction([](ValueFactory& value_factory, absl::Time ts, const StringValue& tz) -> Handle { @@ -863,7 +856,7 @@ absl::Status RegisterTimestampFunctions(CelFunctionRegistry* registry, CEL_RETURN_IF_ERROR(registry->Register( UnaryFunctionAdapter, absl::Time>::CreateDescriptor( - builtin::kMonth, true), + cel::builtin::kMonth, true), UnaryFunctionAdapter, absl::Time>::WrapFunction( [](ValueFactory& value_factory, absl::Time ts) -> Handle { return GetMonth(value_factory, ts, ""); @@ -871,7 +864,7 @@ absl::Status RegisterTimestampFunctions(CelFunctionRegistry* registry, CEL_RETURN_IF_ERROR(registry->Register( BinaryFunctionAdapter, absl::Time, const StringValue&>:: - CreateDescriptor(builtin::kDayOfYear, true), + CreateDescriptor(cel::builtin::kDayOfYear, true), BinaryFunctionAdapter, absl::Time, const StringValue&>:: WrapFunction([](ValueFactory& value_factory, absl::Time ts, const StringValue& tz) -> Handle { @@ -880,7 +873,7 @@ absl::Status RegisterTimestampFunctions(CelFunctionRegistry* registry, CEL_RETURN_IF_ERROR(registry->Register( UnaryFunctionAdapter, absl::Time>::CreateDescriptor( - builtin::kDayOfYear, true), + cel::builtin::kDayOfYear, true), UnaryFunctionAdapter, absl::Time>::WrapFunction( [](ValueFactory& value_factory, absl::Time ts) -> Handle { return GetDayOfYear(value_factory, ts, ""); @@ -888,7 +881,7 @@ absl::Status RegisterTimestampFunctions(CelFunctionRegistry* registry, CEL_RETURN_IF_ERROR(registry->Register( BinaryFunctionAdapter, absl::Time, const StringValue&>:: - CreateDescriptor(builtin::kDayOfMonth, true), + CreateDescriptor(cel::builtin::kDayOfMonth, true), BinaryFunctionAdapter, absl::Time, const StringValue&>:: WrapFunction([](ValueFactory& value_factory, absl::Time ts, const StringValue& tz) -> Handle { @@ -897,7 +890,7 @@ absl::Status RegisterTimestampFunctions(CelFunctionRegistry* registry, CEL_RETURN_IF_ERROR(registry->Register( UnaryFunctionAdapter, absl::Time>::CreateDescriptor( - builtin::kDayOfMonth, true), + cel::builtin::kDayOfMonth, true), UnaryFunctionAdapter, absl::Time>::WrapFunction( [](ValueFactory& value_factory, absl::Time ts) -> Handle { return GetDayOfMonth(value_factory, ts, ""); @@ -905,7 +898,7 @@ absl::Status RegisterTimestampFunctions(CelFunctionRegistry* registry, CEL_RETURN_IF_ERROR(registry->Register( BinaryFunctionAdapter, absl::Time, const StringValue&>:: - CreateDescriptor(builtin::kDate, true), + CreateDescriptor(cel::builtin::kDate, true), BinaryFunctionAdapter, absl::Time, const StringValue&>:: WrapFunction([](ValueFactory& value_factory, absl::Time ts, const StringValue& tz) -> Handle { @@ -914,7 +907,7 @@ absl::Status RegisterTimestampFunctions(CelFunctionRegistry* registry, CEL_RETURN_IF_ERROR(registry->Register( UnaryFunctionAdapter, absl::Time>::CreateDescriptor( - builtin::kDate, true), + cel::builtin::kDate, true), UnaryFunctionAdapter, absl::Time>::WrapFunction( [](ValueFactory& value_factory, absl::Time ts) -> Handle { return GetDate(value_factory, ts, ""); @@ -922,7 +915,7 @@ absl::Status RegisterTimestampFunctions(CelFunctionRegistry* registry, CEL_RETURN_IF_ERROR(registry->Register( BinaryFunctionAdapter, absl::Time, const StringValue&>:: - CreateDescriptor(builtin::kDayOfWeek, true), + CreateDescriptor(cel::builtin::kDayOfWeek, true), BinaryFunctionAdapter, absl::Time, const StringValue&>:: WrapFunction([](ValueFactory& value_factory, absl::Time ts, const StringValue& tz) -> Handle { @@ -931,7 +924,7 @@ absl::Status RegisterTimestampFunctions(CelFunctionRegistry* registry, CEL_RETURN_IF_ERROR(registry->Register( UnaryFunctionAdapter, absl::Time>::CreateDescriptor( - builtin::kDayOfWeek, true), + cel::builtin::kDayOfWeek, true), UnaryFunctionAdapter, absl::Time>::WrapFunction( [](ValueFactory& value_factory, absl::Time ts) -> Handle { return GetDayOfWeek(value_factory, ts, ""); @@ -939,7 +932,7 @@ absl::Status RegisterTimestampFunctions(CelFunctionRegistry* registry, CEL_RETURN_IF_ERROR(registry->Register( BinaryFunctionAdapter, absl::Time, const StringValue&>:: - CreateDescriptor(builtin::kHours, true), + CreateDescriptor(cel::builtin::kHours, true), BinaryFunctionAdapter, absl::Time, const StringValue&>:: WrapFunction([](ValueFactory& value_factory, absl::Time ts, const StringValue& tz) -> Handle { @@ -948,7 +941,7 @@ absl::Status RegisterTimestampFunctions(CelFunctionRegistry* registry, CEL_RETURN_IF_ERROR(registry->Register( UnaryFunctionAdapter, absl::Time>::CreateDescriptor( - builtin::kHours, true), + cel::builtin::kHours, true), UnaryFunctionAdapter, absl::Time>::WrapFunction( [](ValueFactory& value_factory, absl::Time ts) -> Handle { return GetHours(value_factory, ts, ""); @@ -956,7 +949,7 @@ absl::Status RegisterTimestampFunctions(CelFunctionRegistry* registry, CEL_RETURN_IF_ERROR(registry->Register( BinaryFunctionAdapter, absl::Time, const StringValue&>:: - CreateDescriptor(builtin::kMinutes, true), + CreateDescriptor(cel::builtin::kMinutes, true), BinaryFunctionAdapter, absl::Time, const StringValue&>:: WrapFunction([](ValueFactory& value_factory, absl::Time ts, const StringValue& tz) -> Handle { @@ -965,7 +958,7 @@ absl::Status RegisterTimestampFunctions(CelFunctionRegistry* registry, CEL_RETURN_IF_ERROR(registry->Register( UnaryFunctionAdapter, absl::Time>::CreateDescriptor( - builtin::kMinutes, true), + cel::builtin::kMinutes, true), UnaryFunctionAdapter, absl::Time>::WrapFunction( [](ValueFactory& value_factory, absl::Time ts) -> Handle { return GetMinutes(value_factory, ts, ""); @@ -973,7 +966,7 @@ absl::Status RegisterTimestampFunctions(CelFunctionRegistry* registry, CEL_RETURN_IF_ERROR(registry->Register( BinaryFunctionAdapter, absl::Time, const StringValue&>:: - CreateDescriptor(builtin::kSeconds, true), + CreateDescriptor(cel::builtin::kSeconds, true), BinaryFunctionAdapter, absl::Time, const StringValue&>:: WrapFunction([](ValueFactory& value_factory, absl::Time ts, const StringValue& tz) -> Handle { @@ -982,7 +975,7 @@ absl::Status RegisterTimestampFunctions(CelFunctionRegistry* registry, CEL_RETURN_IF_ERROR(registry->Register( UnaryFunctionAdapter, absl::Time>::CreateDescriptor( - builtin::kSeconds, true), + cel::builtin::kSeconds, true), UnaryFunctionAdapter, absl::Time>::WrapFunction( [](ValueFactory& value_factory, absl::Time ts) -> Handle { return GetSeconds(value_factory, ts, ""); @@ -990,7 +983,7 @@ absl::Status RegisterTimestampFunctions(CelFunctionRegistry* registry, CEL_RETURN_IF_ERROR(registry->Register( BinaryFunctionAdapter, absl::Time, const StringValue&>:: - CreateDescriptor(builtin::kMilliseconds, true), + CreateDescriptor(cel::builtin::kMilliseconds, true), BinaryFunctionAdapter, absl::Time, const StringValue&>:: WrapFunction([](ValueFactory& value_factory, absl::Time ts, const StringValue& tz) -> Handle { @@ -999,7 +992,7 @@ absl::Status RegisterTimestampFunctions(CelFunctionRegistry* registry, return registry->Register( UnaryFunctionAdapter, absl::Time>::CreateDescriptor( - builtin::kMilliseconds, true), + cel::builtin::kMilliseconds, true), UnaryFunctionAdapter, absl::Time>::WrapFunction( [](ValueFactory& value_factory, absl::Time ts) -> Handle { return GetMilliseconds(value_factory, ts, ""); @@ -1011,7 +1004,7 @@ absl::Status RegisterBytesConversionFunctions(CelFunctionRegistry* registry, // bytes -> bytes CEL_RETURN_IF_ERROR(registry->Register( UnaryFunctionAdapter, Handle>:: - CreateDescriptor(builtin::kBytes, false), + CreateDescriptor(cel::builtin::kBytes, false), UnaryFunctionAdapter, Handle>:: WrapFunction([](ValueFactory&, Handle value) -> Handle { return value; }))); @@ -1020,7 +1013,7 @@ absl::Status RegisterBytesConversionFunctions(CelFunctionRegistry* registry, return registry->Register( UnaryFunctionAdapter< absl::StatusOr>, - const StringValue&>::CreateDescriptor(builtin::kBytes, false), + const StringValue&>::CreateDescriptor(cel::builtin::kBytes, false), UnaryFunctionAdapter< absl::StatusOr>, const StringValue&>::WrapFunction([](ValueFactory& value_factory, @@ -1034,21 +1027,21 @@ absl::Status RegisterDoubleConversionFunctions(CelFunctionRegistry* registry, // double -> double CEL_RETURN_IF_ERROR( registry->Register(UnaryFunctionAdapter::CreateDescriptor( - builtin::kDouble, false), + cel::builtin::kDouble, false), UnaryFunctionAdapter::WrapFunction( [](ValueFactory&, double v) { return v; }))); // int -> double CEL_RETURN_IF_ERROR(registry->Register( - UnaryFunctionAdapter::CreateDescriptor(builtin::kDouble, - false), + UnaryFunctionAdapter::CreateDescriptor( + cel::builtin::kDouble, false), UnaryFunctionAdapter::WrapFunction( [](ValueFactory&, int64_t v) { return static_cast(v); }))); // string -> double CEL_RETURN_IF_ERROR(registry->Register( UnaryFunctionAdapter, const StringValue&>::CreateDescriptor( - builtin::kDouble, false), + cel::builtin::kDouble, false), UnaryFunctionAdapter, const StringValue&>::WrapFunction( [](ValueFactory& value_factory, const StringValue& s) -> Handle { @@ -1063,8 +1056,8 @@ absl::Status RegisterDoubleConversionFunctions(CelFunctionRegistry* registry, // uint -> double return registry->Register( - UnaryFunctionAdapter::CreateDescriptor(builtin::kDouble, - false), + UnaryFunctionAdapter::CreateDescriptor( + cel::builtin::kDouble, false), UnaryFunctionAdapter::WrapFunction( [](ValueFactory&, uint64_t v) { return static_cast(v); })); } @@ -1073,7 +1066,7 @@ absl::Status RegisterIntConversionFunctions(CelFunctionRegistry* registry, const InterpreterOptions&) { // bool -> int CEL_RETURN_IF_ERROR(registry->Register( - UnaryFunctionAdapter::CreateDescriptor(builtin::kInt, + UnaryFunctionAdapter::CreateDescriptor(cel::builtin::kInt, false), UnaryFunctionAdapter::WrapFunction( [](ValueFactory&, bool v) { return static_cast(v); }))); @@ -1081,7 +1074,7 @@ absl::Status RegisterIntConversionFunctions(CelFunctionRegistry* registry, // double -> int CEL_RETURN_IF_ERROR(registry->Register( UnaryFunctionAdapter, double>::CreateDescriptor( - builtin::kInt, false), + cel::builtin::kInt, false), UnaryFunctionAdapter, double>::WrapFunction( [](ValueFactory& value_factory, double v) -> Handle { auto conv = cel::internal::CheckedDoubleToInt64(v); @@ -1093,15 +1086,15 @@ absl::Status RegisterIntConversionFunctions(CelFunctionRegistry* registry, // int -> int CEL_RETURN_IF_ERROR(registry->Register( - UnaryFunctionAdapter::CreateDescriptor(builtin::kInt, - false), + UnaryFunctionAdapter::CreateDescriptor( + cel::builtin::kInt, false), UnaryFunctionAdapter::WrapFunction( [](ValueFactory&, int64_t v) { return v; }))); // string -> int CEL_RETURN_IF_ERROR(registry->Register( UnaryFunctionAdapter, const StringValue&>::CreateDescriptor( - builtin::kInt, false), + cel::builtin::kInt, false), UnaryFunctionAdapter, const StringValue&>::WrapFunction( [](ValueFactory& value_factory, const StringValue& s) -> Handle { @@ -1115,15 +1108,15 @@ absl::Status RegisterIntConversionFunctions(CelFunctionRegistry* registry, // time -> int CEL_RETURN_IF_ERROR(registry->Register( - UnaryFunctionAdapter::CreateDescriptor(builtin::kInt, - false), + UnaryFunctionAdapter::CreateDescriptor( + cel::builtin::kInt, false), UnaryFunctionAdapter::WrapFunction( [](ValueFactory&, absl::Time t) { return absl::ToUnixSeconds(t); }))); // uint -> int return registry->Register( UnaryFunctionAdapter, uint64_t>::CreateDescriptor( - builtin::kInt, false), + cel::builtin::kInt, false), UnaryFunctionAdapter, uint64_t>::WrapFunction( [](ValueFactory& value_factory, uint64_t v) -> Handle { auto conv = cel::internal::CheckedUint64ToInt64(v); @@ -1143,7 +1136,7 @@ absl::Status RegisterStringConversionFunctions( CEL_RETURN_IF_ERROR(registry->Register( UnaryFunctionAdapter, const BytesValue&>::CreateDescriptor( - builtin::kString, false), + cel::builtin::kString, false), UnaryFunctionAdapter, const BytesValue&>::WrapFunction( [](ValueFactory& value_factory, const BytesValue& value) -> Handle { @@ -1157,7 +1150,7 @@ absl::Status RegisterStringConversionFunctions( // double -> string CEL_RETURN_IF_ERROR(registry->Register( UnaryFunctionAdapter, double>::CreateDescriptor( - builtin::kString, false), + cel::builtin::kString, false), UnaryFunctionAdapter, double>::WrapFunction( [](ValueFactory& value_factory, double value) -> Handle { return value_factory.CreateUncheckedStringValue( @@ -1167,7 +1160,7 @@ absl::Status RegisterStringConversionFunctions( // int -> string CEL_RETURN_IF_ERROR(registry->Register( UnaryFunctionAdapter, int64_t>::CreateDescriptor( - builtin::kString, false), + cel::builtin::kString, false), UnaryFunctionAdapter, int64_t>::WrapFunction( [](ValueFactory& value_factory, int64_t value) -> Handle { @@ -1178,7 +1171,7 @@ absl::Status RegisterStringConversionFunctions( // string -> string CEL_RETURN_IF_ERROR(registry->Register( UnaryFunctionAdapter, Handle>:: - CreateDescriptor(builtin::kString, false), + CreateDescriptor(cel::builtin::kString, false), UnaryFunctionAdapter, Handle>:: WrapFunction([](ValueFactory&, Handle value) -> Handle { return value; }))); @@ -1186,7 +1179,7 @@ absl::Status RegisterStringConversionFunctions( // uint -> string CEL_RETURN_IF_ERROR(registry->Register( UnaryFunctionAdapter, uint64_t>::CreateDescriptor( - builtin::kString, false), + cel::builtin::kString, false), UnaryFunctionAdapter, uint64_t>::WrapFunction( [](ValueFactory& value_factory, uint64_t value) -> Handle { @@ -1197,7 +1190,7 @@ absl::Status RegisterStringConversionFunctions( // duration -> string CEL_RETURN_IF_ERROR(registry->Register( UnaryFunctionAdapter, absl::Duration>::CreateDescriptor( - builtin::kString, false), + cel::builtin::kString, false), UnaryFunctionAdapter, absl::Duration>::WrapFunction( [](ValueFactory& value_factory, absl::Duration value) -> Handle { @@ -1211,7 +1204,7 @@ absl::Status RegisterStringConversionFunctions( // timestamp -> string return registry->Register( UnaryFunctionAdapter, absl::Time>::CreateDescriptor( - builtin::kString, false), + cel::builtin::kString, false), UnaryFunctionAdapter, absl::Time>::WrapFunction( [](ValueFactory& value_factory, absl::Time value) -> Handle { auto encode = EncodeTimeToString(value); @@ -1227,7 +1220,7 @@ absl::Status RegisterUintConversionFunctions(CelFunctionRegistry* registry, // double -> uint CEL_RETURN_IF_ERROR(registry->Register( UnaryFunctionAdapter, double>::CreateDescriptor( - builtin::kUint, false), + cel::builtin::kUint, false), UnaryFunctionAdapter, double>::WrapFunction( [](ValueFactory& value_factory, double v) -> Handle { auto conv = cel::internal::CheckedDoubleToUint64(v); @@ -1240,7 +1233,7 @@ absl::Status RegisterUintConversionFunctions(CelFunctionRegistry* registry, // int -> uint CEL_RETURN_IF_ERROR(registry->Register( UnaryFunctionAdapter, int64_t>::CreateDescriptor( - builtin::kUint, false), + cel::builtin::kUint, false), UnaryFunctionAdapter, int64_t>::WrapFunction( [](ValueFactory& value_factory, int64_t v) -> Handle { auto conv = cel::internal::CheckedInt64ToUint64(v); @@ -1253,7 +1246,7 @@ absl::Status RegisterUintConversionFunctions(CelFunctionRegistry* registry, // string -> uint CEL_RETURN_IF_ERROR(registry->Register( UnaryFunctionAdapter, const StringValue&>::CreateDescriptor( - builtin::kUint, false), + cel::builtin::kUint, false), UnaryFunctionAdapter, const StringValue&>::WrapFunction( [](ValueFactory& value_factory, const StringValue& s) -> Handle { @@ -1267,8 +1260,8 @@ absl::Status RegisterUintConversionFunctions(CelFunctionRegistry* registry, // uint -> uint return registry->Register( - UnaryFunctionAdapter::CreateDescriptor(builtin::kUint, - false), + UnaryFunctionAdapter::CreateDescriptor( + cel::builtin::kUint, false), UnaryFunctionAdapter::WrapFunction( [](ValueFactory&, uint64_t v) { return v; })); } @@ -1282,16 +1275,15 @@ absl::Status RegisterConversionFunctions(CelFunctionRegistry* registry, // duration() conversion from string. CEL_RETURN_IF_ERROR(registry->Register( UnaryFunctionAdapter, const StringValue&>::CreateDescriptor( - builtin::kDuration, false), + cel::builtin::kDuration, false), UnaryFunctionAdapter, const StringValue&>::WrapFunction( CreateDurationFromString))); // dyn() identity function. // TODO(issues/102): strip dyn() function references at type-check time. CEL_RETURN_IF_ERROR(registry->Register( - UnaryFunctionAdapter< - Handle, const Handle&>::CreateDescriptor(builtin::kDyn, - false), + UnaryFunctionAdapter, const Handle&>:: + CreateDescriptor(cel::builtin::kDyn, false), UnaryFunctionAdapter, const Handle&>::WrapFunction( [](ValueFactory&, const Handle& value) -> Handle { return value; @@ -1304,7 +1296,7 @@ absl::Status RegisterConversionFunctions(CelFunctionRegistry* registry, // timestamp conversion from int. CEL_RETURN_IF_ERROR(registry->Register( UnaryFunctionAdapter, int64_t>::CreateDescriptor( - builtin::kTimestamp, false), + cel::builtin::kTimestamp, false), UnaryFunctionAdapter, int64_t>::WrapFunction( [](ValueFactory&, int64_t epoch_seconds) -> Handle { return cel::interop_internal::CreateTimestampValue( @@ -1316,7 +1308,7 @@ absl::Status RegisterConversionFunctions(CelFunctionRegistry* registry, options.enable_timestamp_duration_overflow_errors; CEL_RETURN_IF_ERROR(registry->Register( UnaryFunctionAdapter, const StringValue&>::CreateDescriptor( - builtin::kTimestamp, false), + cel::builtin::kTimestamp, false), UnaryFunctionAdapter, const StringValue&>::WrapFunction( [=](ValueFactory& value_factory, const StringValue& time_str) -> Handle { @@ -1341,9 +1333,8 @@ absl::Status RegisterConversionFunctions(CelFunctionRegistry* registry, absl::Status RegisterCheckedTimeArithmeticFunctions( CelFunctionRegistry* registry) { CEL_RETURN_IF_ERROR(registry->Register( - BinaryFunctionAdapter, absl::Time, - absl::Duration>::CreateDescriptor(builtin::kAdd, - false), + BinaryFunctionAdapter, absl::Time, absl::Duration>:: + CreateDescriptor(cel::builtin::kAdd, false), BinaryFunctionAdapter>, absl::Time, absl::Duration>:: WrapFunction([](ValueFactory& value_factory, absl::Time t1, @@ -1357,7 +1348,8 @@ absl::Status RegisterCheckedTimeArithmeticFunctions( CEL_RETURN_IF_ERROR(registry->Register( BinaryFunctionAdapter>, absl::Duration, - absl::Time>::CreateDescriptor(builtin::kAdd, false), + absl::Time>::CreateDescriptor(cel::builtin::kAdd, + false), BinaryFunctionAdapter>, absl::Duration, absl::Time>:: WrapFunction([](ValueFactory& value_factory, absl::Duration d2, @@ -1370,9 +1362,9 @@ absl::Status RegisterCheckedTimeArithmeticFunctions( }))); CEL_RETURN_IF_ERROR(registry->Register( - BinaryFunctionAdapter>, absl::Duration, - absl::Duration>::CreateDescriptor(builtin::kAdd, - false), + BinaryFunctionAdapter< + absl::StatusOr>, absl::Duration, + absl::Duration>::CreateDescriptor(cel::builtin::kAdd, false), BinaryFunctionAdapter>, absl::Duration, absl::Duration>:: WrapFunction([](ValueFactory& value_factory, absl::Duration d1, @@ -1387,7 +1379,7 @@ absl::Status RegisterCheckedTimeArithmeticFunctions( CEL_RETURN_IF_ERROR(registry->Register( BinaryFunctionAdapter< absl::StatusOr>, absl::Time, - absl::Duration>::CreateDescriptor(builtin::kSubtract, false), + absl::Duration>::CreateDescriptor(cel::builtin::kSubtract, false), BinaryFunctionAdapter>, absl::Time, absl::Duration>:: WrapFunction([](ValueFactory& value_factory, absl::Time t1, @@ -1400,9 +1392,9 @@ absl::Status RegisterCheckedTimeArithmeticFunctions( }))); CEL_RETURN_IF_ERROR(registry->Register( - BinaryFunctionAdapter>, absl::Time, - absl::Time>::CreateDescriptor(builtin::kSubtract, - false), + BinaryFunctionAdapter< + absl::StatusOr>, absl::Time, + absl::Time>::CreateDescriptor(cel::builtin::kSubtract, false), BinaryFunctionAdapter>, absl::Time, absl::Time>:: WrapFunction([](ValueFactory& value_factory, absl::Time t1, @@ -1417,7 +1409,7 @@ absl::Status RegisterCheckedTimeArithmeticFunctions( CEL_RETURN_IF_ERROR(registry->Register( BinaryFunctionAdapter< absl::StatusOr>, absl::Duration, - absl::Duration>::CreateDescriptor(builtin::kSubtract, false), + absl::Duration>::CreateDescriptor(cel::builtin::kSubtract, false), BinaryFunctionAdapter>, absl::Duration, absl::Duration>:: WrapFunction([](ValueFactory& value_factory, absl::Duration d1, @@ -1437,9 +1429,8 @@ absl::Status RegisterUncheckedTimeArithmeticFunctions( // TODO(issues/5): deprecate unchecked time math functions when clients no // longer depend on them. CEL_RETURN_IF_ERROR(registry->Register( - BinaryFunctionAdapter, absl::Time, - absl::Duration>::CreateDescriptor(builtin::kAdd, - false), + BinaryFunctionAdapter, absl::Time, absl::Duration>:: + CreateDescriptor(cel::builtin::kAdd, false), BinaryFunctionAdapter, absl::Time, absl::Duration>:: WrapFunction([](ValueFactory& value_factory, absl::Time t1, absl::Duration d2) -> Handle { @@ -1448,7 +1439,8 @@ absl::Status RegisterUncheckedTimeArithmeticFunctions( CEL_RETURN_IF_ERROR(registry->Register( BinaryFunctionAdapter, absl::Duration, - absl::Time>::CreateDescriptor(builtin::kAdd, false), + absl::Time>::CreateDescriptor(cel::builtin::kAdd, + false), BinaryFunctionAdapter, absl::Duration, absl::Time>:: WrapFunction([](ValueFactory& value_factory, absl::Duration d2, absl::Time t1) -> Handle { @@ -1456,9 +1448,8 @@ absl::Status RegisterUncheckedTimeArithmeticFunctions( }))); CEL_RETURN_IF_ERROR(registry->Register( - BinaryFunctionAdapter, absl::Duration, - absl::Duration>::CreateDescriptor(builtin::kAdd, - false), + BinaryFunctionAdapter, absl::Duration, absl::Duration>:: + CreateDescriptor(cel::builtin::kAdd, false), BinaryFunctionAdapter, absl::Duration, absl::Duration>:: WrapFunction([](ValueFactory& value_factory, absl::Duration d1, absl::Duration d2) -> Handle { @@ -1467,7 +1458,7 @@ absl::Status RegisterUncheckedTimeArithmeticFunctions( CEL_RETURN_IF_ERROR(registry->Register( BinaryFunctionAdapter, absl::Time, absl::Duration>:: - CreateDescriptor(builtin::kSubtract, false), + CreateDescriptor(cel::builtin::kSubtract, false), BinaryFunctionAdapter, absl::Time, absl::Duration>:: WrapFunction( @@ -1478,9 +1469,8 @@ absl::Status RegisterUncheckedTimeArithmeticFunctions( }))); CEL_RETURN_IF_ERROR(registry->Register( - BinaryFunctionAdapter, absl::Time, - absl::Time>::CreateDescriptor(builtin::kSubtract, - false), + BinaryFunctionAdapter, absl::Time, absl::Time>:: + CreateDescriptor(cel::builtin::kSubtract, false), BinaryFunctionAdapter, absl::Time, absl::Time>:: WrapFunction( @@ -1491,7 +1481,7 @@ absl::Status RegisterUncheckedTimeArithmeticFunctions( CEL_RETURN_IF_ERROR(registry->Register( BinaryFunctionAdapter, absl::Duration, absl::Duration>:: - CreateDescriptor(builtin::kSubtract, false), + CreateDescriptor(cel::builtin::kSubtract, false), BinaryFunctionAdapter, absl::Duration, absl::Duration>:: WrapFunction([](ValueFactory& value_factory, absl::Duration d1, absl::Duration d2) -> Handle { @@ -1516,28 +1506,29 @@ absl::Status RegisterTimeFunctions(CelFunctionRegistry* registry, using DurationAccessorFunction = UnaryFunctionAdapter; CEL_RETURN_IF_ERROR(registry->Register( - DurationAccessorFunction::CreateDescriptor(builtin::kHours, true), + DurationAccessorFunction::CreateDescriptor(cel::builtin::kHours, true), DurationAccessorFunction::WrapFunction( [](ValueFactory&, absl::Duration d) -> int64_t { return absl::ToInt64Hours(d); }))); CEL_RETURN_IF_ERROR(registry->Register( - DurationAccessorFunction::CreateDescriptor(builtin::kMinutes, true), + DurationAccessorFunction::CreateDescriptor(cel::builtin::kMinutes, true), DurationAccessorFunction::WrapFunction( [](ValueFactory&, absl::Duration d) -> int64_t { return absl::ToInt64Minutes(d); }))); CEL_RETURN_IF_ERROR(registry->Register( - DurationAccessorFunction::CreateDescriptor(builtin::kSeconds, true), + DurationAccessorFunction::CreateDescriptor(cel::builtin::kSeconds, true), DurationAccessorFunction::WrapFunction( [](ValueFactory&, absl::Duration d) -> int64_t { return absl::ToInt64Seconds(d); }))); CEL_RETURN_IF_ERROR(registry->Register( - DurationAccessorFunction::CreateDescriptor(builtin::kMilliseconds, true), + DurationAccessorFunction::CreateDescriptor(cel::builtin::kMilliseconds, + true), DurationAccessorFunction::WrapFunction( [](ValueFactory&, absl::Duration d) -> int64_t { constexpr int64_t millis_per_second = 1000L; @@ -1566,23 +1557,11 @@ absl::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry, options)); return registry->Register( - UnaryFunctionAdapter< - Handle, const Handle&>::CreateDescriptor(builtin::kType, - false), + UnaryFunctionAdapter, const Handle&>:: + CreateDescriptor(cel::builtin::kType, false), UnaryFunctionAdapter, const Handle&>::WrapFunction( [](ValueFactory& factory, const Handle& value) { - // TODO(issues/5): legacy types don't interop with type values - // from factory. This should simply be: - // - // return factory.CreateTypeValue(value->type()); - Arena* arena = - cel::extensions::ProtoMemoryManager::CastToProtoArena( - factory.memory_manager()); - CelValue legacy_value = - cel::interop_internal::ModernValueToLegacyValueOrDie(arena, - value); - return cel::interop_internal::CreateTypeValueFromView( - legacy_value.ObtainCelType().CelTypeOrDie().value()); + return factory.CreateTypeValue(value->type()); })); } From be02231d55ac13d520bc24e2644468ff823f5c3b Mon Sep 17 00:00:00 2001 From: Justin King Date: Fri, 19 May 2023 15:41:37 -0700 Subject: [PATCH 294/303] Update `CloudBuild` and `Dockerfile` PiperOrigin-RevId: 533569135 --- Dockerfile | 52 ++++++++++++++++++++++++++++++++++++++----------- cloudbuild.yaml | 42 +++++++++++++++++---------------------- 2 files changed, 59 insertions(+), 35 deletions(-) diff --git a/Dockerfile b/Dockerfile index eeae61607..50282b5fc 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,17 +1,47 @@ -FROM gcr.io/gcp-runtimes/ubuntu_20_0_4 +FROM gcc:9 -ENV DEBIAN_FRONTEND=noninteractive +# Install Bazel prerequesites and required tools. +# See https://docs.bazel.build/versions/master/install-ubuntu.html +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + ca-certificates \ + git \ + libssl-dev \ + make \ + pkg-config \ + python3 \ + unzip \ + wget \ + zip \ + zlib1g-dev \ + default-jdk-headless \ + clang-11 && \ + apt-get clean -RUN rm -rf /var/lib/apt/lists/* \ - && apt-get update --fix-missing -qq \ - && apt-get install -qqy --no-install-recommends build-essential ca-certificates tzdata wget git default-jdk clang-12 lld-12 patch \ - && apt-get clean && rm -rf /var/lib/apt/lists/* +# Install Bazel. +# https://github.com/bazelbuild/bazel/releases +ARG BAZEL_VERSION="6.2.0" +ADD https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh /tmp/install_bazel.sh +RUN /bin/bash /tmp/install_bazel.sh && rm /tmp/install_bazel.sh -RUN wget https://github.com/bazelbuild/bazelisk/releases/download/v1.5.0/bazelisk-linux-amd64 && chmod +x bazelisk-linux-amd64 && mv bazelisk-linux-amd64 /bin/bazel - -ENV CC=clang-12 -ENV CXX=clang++-12 +# When Bazel runs, it downloads some of its own implicit +# dependencies. The following command preloads these dependencies. +# Passing `--distdir=/bazel-distdir` to bazel allows it to use these +# dependencies. See +# https://docs.bazel.build/versions/master/guide.html#running-bazel-in-an-airgapped-environment +# for more information. +RUN cd /tmp && \ + git clone https://github.com/bazelbuild/bazel && \ + cd bazel && \ + git checkout ${BAZEL_VERSION} && \ + bazel build @additional_distfiles//:archives.tar && \ + mkdir /bazel-distdir && \ + tar xvf bazel-bin/external/additional_distfiles/archives.tar -C /bazel-distdir --strip-components=3 && \ + cd / && \ + rm -rf /tmp/* && \ + rm -rf /root/.cache/bazel RUN mkdir -p /workspace +RUN mkdir -p /bazel -ENTRYPOINT ["/bin/bazel"] +ENTRYPOINT ["/usr/local/bin/bazel"] diff --git a/cloudbuild.yaml b/cloudbuild.yaml index 8c9398e91..6ec7b3fce 100644 --- a/cloudbuild.yaml +++ b/cloudbuild.yaml @@ -1,35 +1,29 @@ steps: -- name: 'gcr.io/cel-analysis/bazel:ubuntu_20_0_4' - entrypoint: bazel +- name: 'gcr.io/cel-analysis/gcc-9:latest' args: - - '--output_base=/bazel' + - '--output_base=/bazel' # This is mandatory to avoid steps accidently sharing data. - 'test' - - '--test_output=errors' - '...' - id: bazel-test -- name: 'gcr.io/cel-analysis/bazel:ubuntu_20_0_4' - entrypoint: bazel - args: - - '--output_base=/bazel' - - 'test' - - '--config=asan' + - '--compilation_mode=fastbuild' - '--test_output=errors' - - '...' - id: bazel-asan -- name: 'gcr.io/cel-analysis/bazel:ubuntu_20_0_4' - entrypoint: bazel + - '--distdir=/bazel-distdir' + - '--show_timestamps' + id: gcc-9 + waitFor: ['-'] +- name: 'gcr.io/cel-analysis/gcc-9:latest' env: - - 'CC=gcc' - - 'CXX=g++' + - 'CC=clang-11' + - 'CXX=clang++-11' args: - - '--output_base=/bazel' + - '--output_base=/bazel' # This is mandatory to avoid steps accidently sharing data. - 'test' - - '--test_output=errors' - '...' - id: bazel-gcc + - '--compilation_mode=fastbuild' + - '--test_output=errors' + - '--distdir=/bazel-distdir' + - '--show_timestamps' + id: clang-11 + waitFor: ['-'] timeout: 1h options: - machineType: 'N1_HIGHCPU_8' - volumes: - - name: bazel - path: /bazel + machineType: 'N1_HIGHCPU_32' From 9c481a442c32843d7747146d88d6679497bc8a01 Mon Sep 17 00:00:00 2001 From: Justin King Date: Mon, 22 May 2023 10:03:08 -0700 Subject: [PATCH 295/303] Fix GCC warnings related to friending self PiperOrigin-RevId: 534100665 --- base/values/list_value.h | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/base/values/list_value.h b/base/values/list_value.h index 9322351f9..6268990c3 100644 --- a/base/values/list_value.h +++ b/base/values/list_value.h @@ -177,7 +177,7 @@ class LegacyListValue final : public ListValue, public InlineData { MemoryManager& memory_manager) const ABSL_ATTRIBUTE_LIFETIME_BOUND; private: - friend class base_internal::ValueHandle; + friend class ValueHandle; friend class cel::ListValue; template friend struct AnyData; @@ -232,11 +232,9 @@ class AbstractListValue : public ListValue, private: friend class cel::ListValue; - friend class base_internal::LegacyListValue; - friend class base_internal::AbstractListValue; - friend internal::TypeInfo base_internal::GetListValueTypeId( - const ListValue& list_value); - friend class base_internal::ValueHandle; + friend class LegacyListValue; + friend internal::TypeInfo GetListValueTypeId(const ListValue& list_value); + friend class ValueHandle; // Called by CEL_IMPLEMENT_LIST_VALUE() and Is() to perform type checking. virtual internal::TypeInfo TypeId() const = 0; From bdfb23f0c3453274a128924fd5c99873c86d5c4d Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Mon, 22 May 2023 10:03:46 -0700 Subject: [PATCH 296/303] Internal Change PiperOrigin-RevId: 534100904 --- base/ast_internal.h | 18 +++++++++--------- base/attribute.h | 2 +- base/function.h | 2 +- base/function_adapter.h | 2 +- base/function_descriptor.h | 2 +- base/function_result.h | 2 +- base/internal/function_adapter.h | 2 +- base/memory.cc | 4 ++-- base/type_manager.h | 2 +- base/type_registry.h | 2 +- base/values/list_value.h | 2 +- base/values/map_value_builder.h | 2 +- eval/compiler/constant_folding.cc | 14 +++++++------- eval/compiler/flat_expr_builder.cc | 4 ++-- eval/compiler/flat_expr_builder.h | 2 +- .../regex_precompilation_optimization_test.cc | 2 +- eval/eval/const_value_step.h | 2 +- eval/eval/container_access_step_test.cc | 4 ++-- eval/eval/create_list_step.cc | 4 ++-- eval/eval/create_struct_step.cc | 4 ++-- eval/eval/evaluator_core.cc | 2 +- eval/eval/evaluator_core.h | 2 +- eval/internal/errors.cc | 2 +- eval/internal/interop.h | 2 +- eval/public/builtin_func_registrar.cc | 4 ++-- eval/public/cel_type_registry.cc | 2 +- eval/public/cel_value.cc | 6 +++--- eval/public/cel_value.h | 6 +++--- eval/public/equality_function_registrar.cc | 2 +- .../public/equality_function_registrar_test.cc | 4 ++-- .../portable_cel_expr_builder_factory.cc | 2 +- eval/public/string_extension_func_registrar.h | 2 +- eval/public/structs/field_access_impl.cc | 2 +- eval/public/structs/field_access_impl.h | 2 +- eval/public/structs/legacy_type_adapter.h | 2 +- eval/public/structs/legacy_type_provider.h | 4 ++-- eval/tests/memory_safety_test.cc | 2 +- extensions/protobuf/struct_value.cc | 2 +- runtime/activation_interface.h | 2 +- runtime/standard/comparison_functions_test.cc | 2 +- 40 files changed, 66 insertions(+), 66 deletions(-) diff --git a/base/ast_internal.h b/base/ast_internal.h index 30b1a980d..72b9e5d16 100644 --- a/base/ast_internal.h +++ b/base/ast_internal.h @@ -14,7 +14,7 @@ // // Type definitions for internal AST representation. // CEL users should not directly depend on the definitions here. -// TODO(issues/5): move to base/internal +// TODO(uncreated-issue/31): move to base/internal #ifndef THIRD_PARTY_CEL_CPP_BASE_AST_INTERNAL_H_ #define THIRD_PARTY_CEL_CPP_BASE_AST_INTERNAL_H_ @@ -54,7 +54,7 @@ struct Bytes { // `true`, `null`. // // (-- -// TODO(issues/5): Extend or replace the constant with a canonical Value +// TODO(uncreated-issue/9): Extend or replace the constant with a canonical Value // message that can hold any constant object representation supplied or // produced at evaluation time. // --) @@ -302,7 +302,7 @@ class Select { // A call expression, including calls to predefined functions and operators. // // For example, `value == 10`, `size(map_value)`. -// (-- TODO(issues/5): Convert built-in globals to instance methods --) +// (-- TODO(uncreated-issue/11): Convert built-in globals to instance methods --) class Call { public: Call() = default; @@ -349,7 +349,7 @@ class Call { // Lists may either be homogenous, e.g. `[1, 2, 3]`, or heterogeneous, e.g. // `dyn([1, 'hello', 2.0])` // (-- -// TODO(issues/5): Determine how to disable heterogeneous types as a feature +// TODO(uncreated-issue/12): Determine how to disable heterogeneous types as a feature // of type-checking rather than through the language construct 'dyn'. // --) class CreateList { @@ -527,7 +527,7 @@ class CreateStruct { // ``` // // (-- -// TODO(issues/5): ensure comprehensions work equally well on maps and +// TODO(uncreated-issue/13): ensure comprehensions work equally well on maps and // messages. // --) class Comprehension { @@ -942,7 +942,7 @@ class SourceInfo { // `id` the `line_offsets[i] < id_positions[id] < line_offsets[i+1]`. The // column may be derivd from `id_positions[id] - line_offsets[i]`. // - // TODO(issues/5): clarify this documentation + // TODO(uncreated-issue/14): clarify this documentation std::vector line_offsets_; // A map from the parse node id (e.g. `Expr.id`) to the code point offset @@ -1019,7 +1019,7 @@ enum class PrimitiveType { // Well-known protobuf types treated with first-class support in CEL. // -// TODO(issues/5): represent well-known via abstract types (or however) +// TODO(uncreated-issue/15): represent well-known via abstract types (or however) // they will be named. enum class WellKnownType { // Unspecified type. @@ -1157,7 +1157,7 @@ class FunctionType { // Application defined abstract type. // -// TODO(issues/5): decide on final naming for this. +// TODO(uncreated-issue/15): decide on final naming for this. class AbstractType { public: AbstractType() = default; @@ -1260,7 +1260,7 @@ using TypeKind = // Analogous to google::api::expr::v1alpha1::Type. // Represents a CEL type. // -// TODO(issues/5): align with value.proto +// TODO(uncreated-issue/15): align with value.proto class Type { public: Type() = default; diff --git a/base/attribute.h b/base/attribute.h index f972c4bac..69b464fbd 100644 --- a/base/attribute.h +++ b/base/attribute.h @@ -185,7 +185,7 @@ class Attribute final { : impl_(std::make_shared(std::move(variable_name), std::move(qualifier_path))) {} - // TODO(issues/5): remove this constructor as it pulls in proto deps + // TODO(uncreated-issue/16): remove this constructor as it pulls in proto deps Attribute(const google::api::expr::v1alpha1::Expr& variable, std::vector qualifier_path); diff --git a/base/function.h b/base/function.h index 6b71e18a6..7c3b51ffd 100644 --- a/base/function.h +++ b/base/function.h @@ -43,7 +43,7 @@ class Function { // extension function. cel::ValueFactory& value_factory() const { return value_factory_; } - // TODO(issues/5): Add accessors for getting attribute stack and mutable + // TODO(uncreated-issue/24): Add accessors for getting attribute stack and mutable // value stack. private: cel::ValueFactory& value_factory_; diff --git a/base/function_adapter.h b/base/function_adapter.h index d80b436e8..95ca93d84 100644 --- a/base/function_adapter.h +++ b/base/function_adapter.h @@ -14,7 +14,7 @@ // // Definitions for template helpers to wrap C++ functions as CEL extension // function implementations. -// TODO(issues/5): Add generalized version in addition to the common cases +// TODO(uncreated-issue/20): Add generalized version in addition to the common cases // of unary/binary functions. #ifndef THIRD_PARTY_CEL_CPP_BASE_FUNCTION_ADAPTER_H_ diff --git a/base/function_descriptor.h b/base/function_descriptor.h index 499ad9a85..d2a057b9b 100644 --- a/base/function_descriptor.h +++ b/base/function_descriptor.h @@ -42,7 +42,7 @@ class FunctionDescriptor final { // The argmument types the function accepts. // - // TODO(issues/5): make this kinds + // TODO(uncreated-issue/17): make this kinds const std::vector& types() const { return impl_->types; } // if true (strict, default), error or unknown arguments are propagated diff --git a/base/function_result.h b/base/function_result.h index fafab8899..977ceeb90 100644 --- a/base/function_result.h +++ b/base/function_result.h @@ -50,7 +50,7 @@ class FunctionResult final { return descriptor() == other.descriptor(); } - // TODO(issues/5): re-implement argument capture + // TODO(uncreated-issue/5): re-implement argument capture private: FunctionDescriptor descriptor_; diff --git a/base/internal/function_adapter.h b/base/internal/function_adapter.h index bf800a9bb..fc0624109 100644 --- a/base/internal/function_adapter.h +++ b/base/internal/function_adapter.h @@ -52,7 +52,7 @@ struct UnhandledType : std::false_type {}; // Adapts the type param Type to the appropriate Kind. // A static assertion fails if the provided type does not map to a cel::Value // kind. -// TODO(issues/5): Add support for remaining kinds. +// TODO(uncreated-issue/20): Add support for remaining kinds. template constexpr Kind AdaptedKind() { static_assert(UnhandledType::value, diff --git a/base/memory.cc b/base/memory.cc index 82992965d..dce96bc24 100644 --- a/base/memory.cc +++ b/base/memory.cc @@ -146,7 +146,7 @@ void ArenaBlockFree(void* pointer, size_t size) { #else static_cast(size); if (ABSL_PREDICT_FALSE(!VirtualFree(pointer, 0, MEM_RELEASE))) { - // TODO(issues/5): print the error + // TODO(uncreated-issue/8): print the error std::abort(); } #endif @@ -220,7 +220,7 @@ class DefaultArenaMemoryManager final : public ArenaMemoryManager { absl::Mutex mutex_; std::vector blocks_ ABSL_GUARDED_BY(mutex_); std::vector> owned_ ABSL_GUARDED_BY(mutex_); - // TODO(issues/5): we could use a priority queue to keep track of any + // TODO(uncreated-issue/8): we could use a priority queue to keep track of any // unallocated space at the end blocks. }; diff --git a/base/type_manager.h b/base/type_manager.h index e6e975f33..5a7378de2 100644 --- a/base/type_manager.h +++ b/base/type_manager.h @@ -31,7 +31,7 @@ namespace cel { // the instantiation of type implementations, loading of type implementations, // and registering type implementations. // -// TODO(issues/5): more comments after solidifying role +// TODO(uncreated-issue/8): more comments after solidifying role class TypeManager final { public: TypeManager(TypeFactory& type_factory ABSL_ATTRIBUTE_LIFETIME_BOUND, diff --git a/base/type_registry.h b/base/type_registry.h index 3f5e21333..9baccb367 100644 --- a/base/type_registry.h +++ b/base/type_registry.h @@ -19,7 +19,7 @@ namespace cel { -// TODO(issues/5): define interface and consolidate with CelTypeRegistry +// TODO(uncreated-issue/8): define interface and consolidate with CelTypeRegistry class TypeRegistry : public TypeProvider {}; } // namespace cel diff --git a/base/values/list_value.h b/base/values/list_value.h index 6268990c3..732c2e268 100644 --- a/base/values/list_value.h +++ b/base/values/list_value.h @@ -61,7 +61,7 @@ class ListValue : public Value { return static_cast(value); } - // TODO(issues/5): implement iterators so we can have cheap concated lists + // TODO(uncreated-issue/10): implement iterators so we can have cheap concated lists Handle type() const; diff --git a/base/values/map_value_builder.h b/base/values/map_value_builder.h index 79c06a7f7..a9ee14698 100644 --- a/base/values/map_value_builder.h +++ b/base/values/map_value_builder.h @@ -94,7 +94,7 @@ class MapValueBuilder; namespace base_internal { -// TODO(issues/5): add checks ensuring keys and values match their expected +// TODO(uncreated-issue/21): add checks ensuring keys and values match their expected // types for all operations. template diff --git a/eval/compiler/constant_folding.cc b/eval/compiler/constant_folding.cc index db7b6f7e0..5b629254f 100644 --- a/eval/compiler/constant_folding.cc +++ b/eval/compiler/constant_folding.cc @@ -73,7 +73,7 @@ Handle CreateLegacyListBackedHandle( } struct MakeConstantArenaSafeVisitor { - // TODO(issues/5): make the AST to runtime Value conversion work with + // TODO(uncreated-issue/33): make the AST to runtime Value conversion work with // non-arena based cel::MemoryManager. google::protobuf::Arena* arena; @@ -165,7 +165,7 @@ class ConstantFoldingTransform { } bool operator()(const Ident& ident) { - // TODO(issues/5): this could be updated to use the rewrite visitor + // TODO(uncreated-issue/34): this could be updated to use the rewrite visitor // to make changes in-place instead of manually copy. This would avoid // having to understand how to copy all of the information in the original // AST. @@ -281,7 +281,7 @@ class ConstantFoldingTransform { bool all_constant = true; for (int i = 0; i < list_size; i++) { auto& element = list_expr.mutable_elements().emplace_back(); - // TODO(issues/5): Add support for CEL optional. + // TODO(uncreated-issue/34): Add support for CEL optional. all_constant = transform_.Transform(expr_.list_expr().elements()[i], element) && all_constant; @@ -292,7 +292,7 @@ class ConstantFoldingTransform { } if (list_size == 0) { - // TODO(issues/5): need a more robust fix to support generic + // TODO(uncreated-issue/35): need a more robust fix to support generic // comprehensions, but this will allow comprehension list append // optimization to work to prevent quadratic memory consumption for // map/filter. @@ -320,7 +320,7 @@ class ConstantFoldingTransform { auto& new_entry = struct_expr.mutable_entries().emplace_back(); new_entry.set_id(entry.id()); struct { - // TODO(issues/5): Add support for CEL optional. + // TODO(uncreated-issue/34): Add support for CEL optional. ConstantFoldingTransform& transform; const CreateStruct::Entry& entry; CreateStruct::Entry& new_entry; @@ -372,7 +372,7 @@ class ConstantFoldingTransform { // Owns constant values created during folding Arena* arena_; - // TODO(issues/5): make this support generic memory manager and value + // TODO(uncreated-issue/33): make this support generic memory manager and value // factory. This is only safe for interop where we know an arena is always // available. extensions::ProtoMemoryManager memory_manager_; @@ -433,7 +433,7 @@ absl::Status ConstantFoldingExtension::OnPreVisit(PlannerContext& context, } IsConst operator()(const CreateList& create_list) { if (create_list.elements().empty()) { - // TODO(issues/5): Don't fold for empty list to allow comprehension + // TODO(uncreated-issue/35): Don't fold for empty list to allow comprehension // list append optimization. return IsConst::kNonConst; } diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index 490193d94..a0de17425 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -267,7 +267,7 @@ class FlatExprVisitor : public cel::ast::internal::AstVisitor { if (!progress_status_.ok()) { return; } - // TODO(issues/5): this will be generalized later. + // TODO(uncreated-issue/27): this will be generalized later. if (program_optimizers_.empty()) { return; } @@ -1177,7 +1177,7 @@ FlatExprBuilder::CreateExpression(const CheckedExpr* checked_expr) const { return CreateExpression(checked_expr, /*warnings=*/nullptr); } -// TODO(issues/5): move ast conversion to client responsibility and +// TODO(uncreated-issue/31): move ast conversion to client responsibility and // update pre-processing steps to work without mutating the input AST. absl::StatusOr> FlatExprBuilder::CreateExpressionImpl( diff --git a/eval/compiler/flat_expr_builder.h b/eval/compiler/flat_expr_builder.h index 0324ffcbf..c0f6a69ee 100644 --- a/eval/compiler/flat_expr_builder.h +++ b/eval/compiler/flat_expr_builder.h @@ -44,7 +44,7 @@ class FlatExprBuilder : public CelExpressionBuilder { // Toggle constant folding optimization. By default it is not enabled. // The provided arena is used to hold the generated constants. - // TODO(issues/5): default enable the updated version then deprecate this + // TODO(uncreated-issue/27): default enable the updated version then deprecate this // function. void set_constant_folding(bool enabled, google::protobuf::Arena* arena) { constant_folding_ = enabled; diff --git a/eval/compiler/regex_precompilation_optimization_test.cc b/eval/compiler/regex_precompilation_optimization_test.cc index 973c28ecd..da1a0b01c 100644 --- a/eval/compiler/regex_precompilation_optimization_test.cc +++ b/eval/compiler/regex_precompilation_optimization_test.cc @@ -166,7 +166,7 @@ TEST_F(RegexPrecompilationExtensionTest, DoesNotOptimizeCompoundExpr) { class RegexConstFoldInteropTest : public RegexPrecompilationExtensionTest { public: RegexConstFoldInteropTest() : RegexPrecompilationExtensionTest() { - // TODO(issues/5): This applies to either version of const folding. + // TODO(uncreated-issue/27): This applies to either version of const folding. // Update when default is changed to new version. builder_.set_constant_folding(true, &arena_); } diff --git a/eval/eval/const_value_step.h b/eval/eval/const_value_step.h index 15ae6408f..4fdc3cc9f 100644 --- a/eval/eval/const_value_step.h +++ b/eval/eval/const_value_step.h @@ -12,7 +12,7 @@ namespace google::api::expr::runtime { -// TODO(issues/5): move this somewhere else +// TODO(uncreated-issue/29): move this somewhere else cel::Handle ConvertConstant( const cel::ast::internal::Constant& const_expr); diff --git a/eval/eval/container_access_step_test.cc b/eval/eval/container_access_step_test.cc index 9ec7fd6f9..6f88ee2d5 100644 --- a/eval/eval/container_access_step_test.cc +++ b/eval/eval/container_access_step_test.cc @@ -425,7 +425,7 @@ TEST_F(ContainerAccessHeterogeneousLookupsTest, DoubleListIndexNotAnInt) { // treat uint as uint before trying coercion to signed int. TEST_F(ContainerAccessHeterogeneousLookupsTest, UintKeyAsUint) { - // TODO(issues/5): Map creation should error here instead of permitting + // TODO(uncreated-issue/4): Map creation should error here instead of permitting // mixed key types with equivalent values. ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("{1u: 2u, 1: 2}[1u]")); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression( @@ -554,7 +554,7 @@ TEST_F(ContainerAccessHeterogeneousLookupsDisabledTest, } TEST_F(ContainerAccessHeterogeneousLookupsDisabledTest, UintKeyAsUint) { - // TODO(issues/5): Map creation should error here instead of permitting + // TODO(uncreated-issue/4): Map creation should error here instead of permitting // mixed key types with equivalent values. ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("{1u: 2u, 1: 2}[1u]")); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression( diff --git a/eval/eval/create_list_step.cc b/eval/eval/create_list_step.cc index e0124d3dd..81ffe9bb0 100644 --- a/eval/eval/create_list_step.cc +++ b/eval/eval/create_list_step.cc @@ -76,13 +76,13 @@ absl::Status CreateListStep::Evaluate(ExecutionFrame* frame) const { frame->memory_manager()); if (immutable_) { - // TODO(issues/5): switch to new cel::ListValue in phase 2 + // TODO(uncreated-issue/23): switch to new cel::ListValue in phase 2 result = CreateLegacyListValue(google::protobuf::Arena::Create( arena, ModernValueToLegacyValueOrDie(frame->memory_manager(), args))); } else { - // TODO(issues/5): switch to new cel::ListValue in phase 2 + // TODO(uncreated-issue/23): switch to new cel::ListValue in phase 2 result = CreateLegacyListValue(google::protobuf::Arena::Create( arena, ModernValueToLegacyValueOrDie(frame->memory_manager(), args))); } diff --git a/eval/eval/create_struct_step.cc b/eval/eval/create_struct_step.cc index 70d3203cd..336f6fc29 100644 --- a/eval/eval/create_struct_step.cc +++ b/eval/eval/create_struct_step.cc @@ -79,7 +79,7 @@ absl::StatusOr> CreateStructStepForMessage::DoEvaluate( } } - // TODO(issues/5): switch to new cel::StructValue in phase 2 + // TODO(uncreated-issue/32): switch to new cel::StructValue in phase 2 CEL_ASSIGN_OR_RETURN(MessageWrapper::Builder instance, type_adapter_->NewInstance(frame->memory_manager())); @@ -131,7 +131,7 @@ absl::StatusOr> CreateStructStepForMap::DoEvaluate( } } - // TODO(issues/5): switch to new cel::MapValue in phase 2 + // TODO(uncreated-issue/32): switch to new cel::MapValue in phase 2 auto* map_builder = google::protobuf::Arena::Create( cel::extensions::ProtoMemoryManager::CastToProtoArena( frame->memory_manager())); diff --git a/eval/eval/evaluator_core.cc b/eval/eval/evaluator_core.cc index 9e0683b5e..8b9012b6f 100644 --- a/eval/eval/evaluator_core.cc +++ b/eval/eval/evaluator_core.cc @@ -31,7 +31,7 @@ absl::Status InvalidIterationStateError() { } // namespace -// TODO(issues/5): cel::TypeFactory and family are setup here assuming legacy +// TODO(uncreated-issue/28): cel::TypeFactory and family are setup here assuming legacy // value interop. Later, these will need to be configurable by clients. CelExpressionFlatEvaluationState::CelExpressionFlatEvaluationState( size_t value_stack_size, google::protobuf::Arena* arena) diff --git a/eval/eval/evaluator_core.h b/eval/eval/evaluator_core.h index 2340b66df..54679ed22 100644 --- a/eval/eval/evaluator_core.h +++ b/eval/eval/evaluator_core.h @@ -121,7 +121,7 @@ class CelExpressionFlatEvaluationState : public CelEvaluationState { cel::ValueFactory& value_factory() { return value_factory_; } private: - // TODO(issues/5): State owns a ProtoMemoryManager to adapt from the client + // TODO(uncreated-issue/1): State owns a ProtoMemoryManager to adapt from the client // provided arena. In the future, clients will have to maintain the particular // manager they want to use for evaluation. cel::extensions::ProtoMemoryManager memory_manager_; diff --git a/eval/internal/errors.cc b/eval/internal/errors.cc index 60f170457..73713a529 100644 --- a/eval/internal/errors.cc +++ b/eval/internal/errors.cc @@ -88,7 +88,7 @@ const absl::Status* CreateMissingAttributeError( const absl::Status* CreateMissingAttributeError( cel::MemoryManager& manager, absl::string_view missing_attribute_path) { - // TODO(issues/5): assume arena-style allocator while migrating + // TODO(uncreated-issue/1): assume arena-style allocator while migrating // to new value type. return CreateMissingAttributeError( extensions::ProtoMemoryManager::CastToProtoArena(manager), diff --git a/eval/internal/interop.h b/eval/internal/interop.h index 3791b1895..5ef6dbd4b 100644 --- a/eval/internal/interop.h +++ b/eval/internal/interop.h @@ -122,7 +122,7 @@ Handle CreateDurationValue(absl::Duration value, bool unchecked = false); // Create a modern timestamp value, without validation. Should only be used // during interoperation. -// TODO(issues/5): Consider adding a check that the timestamp is in the +// TODO(uncreated-issue/39): Consider adding a check that the timestamp is in the // supported range for CEL. Handle CreateTimestampValue(absl::Time value); diff --git a/eval/public/builtin_func_registrar.cc b/eval/public/builtin_func_registrar.cc index 1c0955e0e..04b3ee6d1 100644 --- a/eval/public/builtin_func_registrar.cc +++ b/eval/public/builtin_func_registrar.cc @@ -705,7 +705,7 @@ absl::Status RegisterSetMembershipFunctions(CelFunctionRegistry* registry, return absl::OkStatus(); } -// TODO(issues/5): after refactors for the new value type are done, move this +// TODO(uncreated-issue/36): after refactors for the new value type are done, move this // to a separate build target to enable subset environments to not depend on // RE2. absl::Status RegisterRegexFunctions(CelFunctionRegistry* registry, @@ -1426,7 +1426,7 @@ absl::Status RegisterCheckedTimeArithmeticFunctions( absl::Status RegisterUncheckedTimeArithmeticFunctions( CelFunctionRegistry* registry) { - // TODO(issues/5): deprecate unchecked time math functions when clients no + // TODO(uncreated-issue/37): deprecate unchecked time math functions when clients no // longer depend on them. CEL_RETURN_IF_ERROR(registry->Register( BinaryFunctionAdapter, absl::Time, absl::Duration>:: diff --git a/eval/public/cel_type_registry.cc b/eval/public/cel_type_registry.cc index caff983df..5fedefe5a 100644 --- a/eval/public/cel_type_registry.cc +++ b/eval/public/cel_type_registry.cc @@ -135,7 +135,7 @@ class ResolveableEnumType final : public cel::EnumType { } std::string name_; - // TODO(issues/5): this could be indexed by name and/or number if strong + // TODO(uncreated-issue/42): this could be indexed by name and/or number if strong // enum typing is needed at runtime. std::vector enumerators_; }; diff --git a/eval/public/cel_value.cc b/eval/public/cel_value.cc index 2086c47b8..645cd2124 100644 --- a/eval/public/cel_value.cc +++ b/eval/public/cel_value.cc @@ -227,7 +227,7 @@ const std::string CelValue::DebugString() const { CelValue CreateErrorValue(cel::MemoryManager& manager, absl::string_view message, absl::StatusCode error_code) { - // TODO(issues/5): assume arena-style allocator while migrating to new + // TODO(uncreated-issue/1): assume arena-style allocator while migrating to new // value type. Arena* arena = cel::extensions::ProtoMemoryManager::CastToProtoArena(manager); return CreateErrorValue(arena, message, error_code); @@ -235,7 +235,7 @@ CelValue CreateErrorValue(cel::MemoryManager& manager, CelValue CreateErrorValue(cel::MemoryManager& manager, const absl::Status& status) { - // TODO(issues/5): assume arena-style allocator while migrating to new + // TODO(uncreated-issue/1): assume arena-style allocator while migrating to new // value type. Arena* arena = cel::extensions::ProtoMemoryManager::CastToProtoArena(manager); return CreateErrorValue(arena, status); @@ -302,7 +302,7 @@ CelValue CreateMissingAttributeError(google::protobuf::Arena* arena, CelValue CreateMissingAttributeError(cel::MemoryManager& manager, absl::string_view missing_attribute_path) { - // TODO(issues/5): assume arena-style allocator while migrating + // TODO(uncreated-issue/1): assume arena-style allocator while migrating // to new value type. return CelValue::CreateError( interop::CreateMissingAttributeError(manager, missing_attribute_path)); diff --git a/eval/public/cel_value.h b/eval/public/cel_value.h index 6bb60beb6..9aeac4dfe 100644 --- a/eval/public/cel_value.h +++ b/eval/public/cel_value.h @@ -394,7 +394,7 @@ class CelValue { // Invokes op() with the active value, and returns the result. // All overloads of op() must have the same return type. - // TODO(issues/5): Move to CelProtoWrapper to retain the assumed + // TODO(uncreated-issue/2): Move to CelProtoWrapper to retain the assumed // google::protobuf::Message variant version behavior for client code. template ReturnType Visit(Op&& op) const { @@ -414,7 +414,7 @@ class CelValue { // Factory for message wrapper. This should only be used by internal // libraries. - // TODO(issues/5): exposed for testing while wiring adapter APIs. Should + // TODO(uncreated-issue/2): exposed for testing while wiring adapter APIs. Should // make private visibility after refactors are done. static CelValue CreateMessageWrapper(MessageWrapper value) { CheckNullPointer(value.message_ptr(), Type::kMessage); @@ -444,7 +444,7 @@ class CelValue { // Specialization for MessageWrapper to support legacy behavior while // migrating off hard dependency on google::protobuf::Message. - // TODO(issues/5): Move to CelProtoWrapper. + // TODO(uncreated-issue/2): Move to CelProtoWrapper. template struct AssignerOp< T, std::enable_if_t>> { diff --git a/eval/public/equality_function_registrar.cc b/eval/public/equality_function_registrar.cc index a61497d4b..3f2f760c8 100644 --- a/eval/public/equality_function_registrar.cc +++ b/eval/public/equality_function_registrar.cc @@ -426,7 +426,7 @@ absl::optional CelValueEqualImpl(const CelValue& v1, const CelValue& v2) { return *lhs == *rhs; } - // TODO(issues/5): It's currently possible for the interpreter to create a + // TODO(uncreated-issue/6): It's currently possible for the interpreter to create a // map containing an Error. Return no matching overload to propagate an error // instead of a false result. if (v1.IsError() || v1.IsUnknownSet() || v2.IsError() || v2.IsUnknownSet()) { diff --git a/eval/public/equality_function_registrar_test.cc b/eval/public/equality_function_registrar_test.cc index 43111c748..eba219435 100644 --- a/eval/public/equality_function_registrar_test.cc +++ b/eval/public/equality_function_registrar_test.cc @@ -656,7 +656,7 @@ INSTANTIATE_TEST_SUITE_P( // This should fail before getting to the equal operator. {"no_such_identifier == 1", EqualityTestCase::ErrorKind::kMissingIdentifier}, - // TODO(issues/5): The C++ evaluator allows creating maps + // TODO(uncreated-issue/6): The C++ evaluator allows creating maps // with error values. Propagate an error instead of a false // result. {"{1: no_such_identifier} == {1: 1}", @@ -684,7 +684,7 @@ INSTANTIATE_TEST_SUITE_P( // This should fail before getting to the equal operator. {"no_such_identifier != 1", EqualityTestCase::ErrorKind::kMissingIdentifier}, - // TODO(issues/5): The C++ evaluator allows creating maps + // TODO(uncreated-issue/6): The C++ evaluator allows creating maps // with error values. Propagate an error instead of a false // result. {"{1: no_such_identifier} != {1: 1}", diff --git a/eval/public/portable_cel_expr_builder_factory.cc b/eval/public/portable_cel_expr_builder_factory.cc index d920a2125..50e73cd35 100644 --- a/eval/public/portable_cel_expr_builder_factory.cc +++ b/eval/public/portable_cel_expr_builder_factory.cc @@ -47,7 +47,7 @@ std::unique_ptr CreatePortableExprBuilder( (options.enable_qualified_identifier_rewrites) ? ReferenceResolverOption::kAlways : ReferenceResolverOption::kCheckedOnly)); - // TODO(issues/5): These need to be abstracted to avoid bringing in too + // TODO(uncreated-issue/27): These need to be abstracted to avoid bringing in too // many build dependencies by default. builder->set_enable_comprehension_vulnerability_check( options.enable_comprehension_vulnerability_check); diff --git a/eval/public/string_extension_func_registrar.h b/eval/public/string_extension_func_registrar.h index bebe77a80..9772092e1 100644 --- a/eval/public/string_extension_func_registrar.h +++ b/eval/public/string_extension_func_registrar.h @@ -21,7 +21,7 @@ namespace google::api::expr::runtime { // Register string related widely used extension functions. -// TODO(issues/5): Move String extension function to +// TODO(uncreated-issue/22): Move String extension function to // extensions absl::Status RegisterStringExtensionFunctions( CelFunctionRegistry* registry, diff --git a/eval/public/structs/field_access_impl.cc b/eval/public/structs/field_access_impl.cc index 7cc64fadb..d0766c85f 100644 --- a/eval/public/structs/field_access_impl.cc +++ b/eval/public/structs/field_access_impl.cc @@ -538,7 +538,7 @@ bool MergeFromWithSerializeFallback(const google::protobuf::Message& value, field.MergeFrom(value); return true; } - // TODO(issues/5): this indicates means we're mixing dynamic messages with + // TODO(uncreated-issue/26): this indicates means we're mixing dynamic messages with // generated messages. This is expected for WKTs where CEL explicitly requires // wire format compatibility, but this may not be the expected behavior for // other types. diff --git a/eval/public/structs/field_access_impl.h b/eval/public/structs/field_access_impl.h index 4e2caca64..78e22e5ba 100644 --- a/eval/public/structs/field_access_impl.h +++ b/eval/public/structs/field_access_impl.h @@ -49,7 +49,7 @@ absl::StatusOr CreateValueFromRepeatedField( // desc Descriptor of the field to access. // value_ref pointer to map value. // arena Arena object to allocate result on, if needed. -// TODO(issues/5): This should be inlined into the FieldBackedMap +// TODO(uncreated-issue/7): This should be inlined into the FieldBackedMap // implementation. absl::StatusOr CreateValueFromMapValue( const google::protobuf::Message* msg, const google::protobuf::FieldDescriptor* desc, diff --git a/eval/public/structs/legacy_type_adapter.h b/eval/public/structs/legacy_type_adapter.h index a21e6c795..e7761f870 100644 --- a/eval/public/structs/legacy_type_adapter.h +++ b/eval/public/structs/legacy_type_adapter.h @@ -35,7 +35,7 @@ class LegacyTypeMutationApis { virtual ~LegacyTypeMutationApis() = default; // Return whether the type defines the given field. - // TODO(issues/5): This is only used to eagerly fail during the planning + // TODO(uncreated-issue/3): This is only used to eagerly fail during the planning // phase. Check if it's safe to remove this behavior and fail at runtime. virtual bool DefinesField(absl::string_view field_name) const = 0; diff --git a/eval/public/structs/legacy_type_provider.h b/eval/public/structs/legacy_type_provider.h index a563c73a0..eea5d44b3 100644 --- a/eval/public/structs/legacy_type_provider.h +++ b/eval/public/structs/legacy_type_provider.h @@ -34,7 +34,7 @@ class LegacyTypeProvider : public cel::TypeProvider { // // Returned non-null pointers from the adapter implemententation must remain // valid as long as the type provider. - // TODO(issues/5): add alternative for new type system. + // TODO(uncreated-issue/3): add alternative for new type system. virtual absl::optional ProvideLegacyType( absl::string_view name) const = 0; @@ -58,7 +58,7 @@ class LegacyTypeProvider : public cel::TypeProvider { // nullopt values are interpreted as not present. // // Returned non-null pointers must remain valid as long as the type provider. - // TODO(issues/5): Move protobuf-Any API from top level + // TODO(uncreated-issue/19): Move protobuf-Any API from top level // [Legacy]TypeProviders. virtual absl::optional ProvideLegacyAnyPackingApis( diff --git a/eval/tests/memory_safety_test.cc b/eval/tests/memory_safety_test.cc index becd93fd1..fa1585476 100644 --- a/eval/tests/memory_safety_test.cc +++ b/eval/tests/memory_safety_test.cc @@ -195,7 +195,7 @@ TEST_P(EvaluatorMemorySafetyTest, NoAstDependency) { EXPECT_THAT(got, IsOkAndHolds(test_case.expected_matcher)); } -// TODO(issues/5): make expression plan memory safe after builder is freed. +// TODO(uncreated-issue/25): make expression plan memory safe after builder is freed. // TEST_P(EvaluatorMemorySafetyTest, NoBuilderDependency) INSTANTIATE_TEST_SUITE_P( diff --git a/extensions/protobuf/struct_value.cc b/extensions/protobuf/struct_value.cc index da0949afd..7d952c263 100644 --- a/extensions/protobuf/struct_value.cc +++ b/extensions/protobuf/struct_value.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// TODO(issues/5): get test coverage closer to 100% before using +// TODO(uncreated-issue/30): get test coverage closer to 100% before using #include "extensions/protobuf/struct_value.h" diff --git a/runtime/activation_interface.h b/runtime/activation_interface.h index 266f4494d..e5798d754 100644 --- a/runtime/activation_interface.h +++ b/runtime/activation_interface.h @@ -31,7 +31,7 @@ namespace cel { // // Clients should prefer to use one of the concrete implementations provided by // the CEL library rather than implementing this interface directly. -// TODO(issues/5): After finalizing, make this public and add instructions +// TODO(uncreated-issue/40): After finalizing, make this public and add instructions // for clients to migrate. class ActivationInterface { public: diff --git a/runtime/standard/comparison_functions_test.cc b/runtime/standard/comparison_functions_test.cc index cd85fdeec..062d693db 100644 --- a/runtime/standard/comparison_functions_test.cc +++ b/runtime/standard/comparison_functions_test.cc @@ -75,7 +75,7 @@ TEST(RegisterComparisonFunctionsTest, GreaterThanOrEqualDefined) { } } -// TODO(issues/5): move functional tests from wrapper library after top-level +// TODO(uncreated-issue/41): move functional tests from wrapper library after top-level // APIs are available for planning and running an expression. } // namespace From f7210edb158ca78c32816e05d78abd0da4fdecec Mon Sep 17 00:00:00 2001 From: Justin King Date: Mon, 22 May 2023 10:08:57 -0700 Subject: [PATCH 297/303] Skip benchmarks during `CloudBuild` testing PiperOrigin-RevId: 534102578 --- cloudbuild.yaml | 2 ++ eval/tests/BUILD | 4 ++++ parser/BUILD | 1 + 3 files changed, 7 insertions(+) diff --git a/cloudbuild.yaml b/cloudbuild.yaml index 6ec7b3fce..de514b9e8 100644 --- a/cloudbuild.yaml +++ b/cloudbuild.yaml @@ -8,6 +8,7 @@ steps: - '--test_output=errors' - '--distdir=/bazel-distdir' - '--show_timestamps' + - '--test_tag_filters=-benchmark' id: gcc-9 waitFor: ['-'] - name: 'gcr.io/cel-analysis/gcc-9:latest' @@ -22,6 +23,7 @@ steps: - '--test_output=errors' - '--distdir=/bazel-distdir' - '--show_timestamps' + - '--test_tag_filters=-benchmark' id: clang-11 waitFor: ['-'] timeout: 1h diff --git a/eval/tests/BUILD b/eval/tests/BUILD index cbe3b246a..626c67a92 100644 --- a/eval/tests/BUILD +++ b/eval/tests/BUILD @@ -45,6 +45,7 @@ cc_library( cc_test( name = "benchmark_test", size = "small", + tags = ["benchmark"], deps = [ ":benchmark_testlib", "@com_github_google_benchmark//:benchmark", @@ -59,6 +60,7 @@ cc_test( name = "const_folding_benchmark_test", size = "small", args = ["--enable_optimizations"], + tags = ["benchmark"], deps = [ ":benchmark_testlib", "@com_github_google_benchmark//:benchmark", @@ -72,6 +74,7 @@ cc_test( srcs = [ "allocation_benchmark_test.cc", ], + tags = ["benchmark"], deps = [ ":request_context_cc_proto", "//eval/public:activation", @@ -130,6 +133,7 @@ cc_test( srcs = [ "expression_builder_benchmark_test.cc", ], + tags = ["benchmark"], deps = [ ":request_context_cc_proto", "//eval/public:builtin_func_registrar", diff --git a/parser/BUILD b/parser/BUILD index 95b073921..f7b7b51fe 100644 --- a/parser/BUILD +++ b/parser/BUILD @@ -112,6 +112,7 @@ cc_library( cc_test( name = "parser_test", srcs = ["parser_test.cc"], + tags = ["benchmark"], deps = [ ":options", ":parser", From f86575baa3604a7e9243eeec7d09d7b2d92062ab Mon Sep 17 00:00:00 2001 From: Justin King Date: Mon, 22 May 2023 10:20:56 -0700 Subject: [PATCH 298/303] Allow dynamically building `ListValue`, `MapValue`, and `StructValue` PiperOrigin-RevId: 534106408 --- base/BUILD | 1 + base/handle.h | 50 +++++++++ base/types/list_type.cc | 39 +++++++ base/types/list_type.h | 13 ++- base/types/map_type.cc | 53 +++++++++ base/types/map_type.h | 12 ++- base/types/struct_type.cc | 22 ++++ base/types/struct_type.h | 20 +++- base/values/list_value_builder_test.cc | 41 +++++++ base/values/map_value_builder_test.cc | 137 ++++++++++++++++++++++++ base/values/struct_value.h | 2 + base/values/struct_value_builder.h | 40 +++++++ extensions/protobuf/BUILD | 2 +- extensions/protobuf/struct_type_test.cc | 13 +++ 14 files changed, 441 insertions(+), 4 deletions(-) create mode 100644 base/values/struct_value_builder.h diff --git a/base/BUILD b/base/BUILD index f3065cb83..6c5211d69 100644 --- a/base/BUILD +++ b/base/BUILD @@ -46,6 +46,7 @@ cc_library( deps = [ "//base/internal:data", "//base/internal:handle", + "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log:absl_check", ], diff --git a/base/handle.h b/base/handle.h index fdf990b70..45d67ddbe 100644 --- a/base/handle.h +++ b/base/handle.h @@ -15,10 +15,14 @@ #ifndef THIRD_PARTY_CEL_CPP_BASE_HANDLE_H_ #define THIRD_PARTY_CEL_CPP_BASE_HANDLE_H_ +#include +#include #include #include #include "absl/base/attributes.h" +#include "absl/base/casts.h" +#include "absl/base/macros.h" #include "absl/log/absl_check.h" #include "base/internal/data.h" #include "base/internal/handle.h" // IWYU pragma: export @@ -302,6 +306,52 @@ struct HandleFactory { template static std::enable_if_t, Handle> Make( MemoryManager& memory_manager, Args&&... args); + + // Constructs a handle from `*this` for classes which extend `Type` and + // `Value`. + template + static Handle FromThis(F& self) { + if constexpr (std::is_base_of_v) { + // If F is derived from InlineData, we don't need to perform runtime + // checks. This is selected at compile time. + return Handle(*absl::bit_cast*>( + static_cast(std::addressof(self)))); + } + if constexpr (std::is_base_of_v) { + // If F is derived from HeapData, we don't need to perform runtime checks. + // This is selected at compile time. + if (Metadata::IsReferenceCounted(self)) { + Metadata::Ref(self); + return Handle(kInPlaceReferenceCounted, self); + } + // Must be arena allocated. + ABSL_ASSERT(Metadata::IsArenaAllocated(self)); + return Handle(kInPlaceArenaAllocated, self); + } + // Perform runtime checks, F is not derived from InlineData or HeapData so + // it must be a abstract base class. + if (Metadata::IsStoredInline(self)) { + return Handle(*absl::bit_cast*>( + static_cast(std::addressof(self)))); + } + // Must be heap allocated. + if (Metadata::IsReferenceCounted(self)) { + Metadata::Ref(self); + return Handle(kInPlaceReferenceCounted, self); + } + // Must be arena allocated. + ABSL_ASSERT(Metadata::IsArenaAllocated(self)); + return Handle(kInPlaceArenaAllocated, self); + } +}; + +template +struct EnableHandleFromThis { + protected: + Handle handle_from_this() const { + return HandleFactory::FromThis( + const_cast(*reinterpret_cast(this))); + } }; } // namespace cel::base_internal diff --git a/base/types/list_type.cc b/base/types/list_type.cc index 60a1d5e0a..777b46333 100644 --- a/base/types/list_type.cc +++ b/base/types/list_type.cc @@ -22,7 +22,12 @@ #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "base/internal/data.h" +#include "base/kind.h" +#include "base/memory.h" #include "base/types/dyn_type.h" +#include "base/value_factory.h" +#include "base/values/list_value.h" +#include "base/values/list_value_builder.h" namespace cel { @@ -48,6 +53,40 @@ const Handle& ListType::element() const { return static_cast(*this).element(); } +absl::StatusOr> ListType::NewValueBuilder( + ValueFactory& value_factory) const { + switch (element()->kind()) { + case Kind::kBool: + return MakeUnique>( + value_factory.memory_manager(), base_internal::kComposedListType, + value_factory, handle_from_this()); + case Kind::kInt: + return MakeUnique>( + value_factory.memory_manager(), base_internal::kComposedListType, + value_factory, handle_from_this()); + case Kind::kUint: + return MakeUnique>( + value_factory.memory_manager(), base_internal::kComposedListType, + value_factory, handle_from_this()); + case Kind::kDouble: + return MakeUnique>( + value_factory.memory_manager(), base_internal::kComposedListType, + value_factory, handle_from_this()); + case Kind::kDuration: + return MakeUnique>( + value_factory.memory_manager(), base_internal::kComposedListType, + value_factory, handle_from_this()); + case Kind::kTimestamp: + return MakeUnique>( + value_factory.memory_manager(), base_internal::kComposedListType, + value_factory, handle_from_this()); + default: + return MakeUnique>( + value_factory.memory_manager(), base_internal::kComposedListType, + value_factory, handle_from_this()); + } +} + namespace base_internal { const Handle& LegacyListType::element() const { diff --git a/base/types/list_type.h b/base/types/list_type.h index 530898074..d60cd516b 100644 --- a/base/types/list_type.h +++ b/base/types/list_type.h @@ -19,9 +19,12 @@ #include #include +#include "absl/base/attributes.h" #include "absl/log/absl_check.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "base/handle.h" #include "base/internal/data.h" #include "base/kind.h" #include "base/memory.h" @@ -31,10 +34,14 @@ namespace cel { class MemoryManager; class ListValue; +class ValueFactory; +class ListValueBuilderInterface; // ListType represents a list type. A list is a sequential container where each // element is the same type. -class ListType : public Type { +class ListType + : public Type, + public base_internal::EnableHandleFromThis { public: static constexpr Kind kKind = Kind::kList; @@ -56,6 +63,10 @@ class ListType : public Type { return static_cast(type); } + absl::StatusOr> NewValueBuilder( + ValueFactory& value_factory + ABSL_ATTRIBUTE_LIFETIME_BOUND) const ABSL_ATTRIBUTE_LIFETIME_BOUND; + private: friend class Type; friend class MemoryManager; diff --git a/base/types/map_type.cc b/base/types/map_type.cc index 5fdf9e2a7..12c2b3d23 100644 --- a/base/types/map_type.cc +++ b/base/types/map_type.cc @@ -22,7 +22,12 @@ #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "base/internal/data.h" +#include "base/kind.h" +#include "base/memory.h" #include "base/types/dyn_type.h" +#include "base/value_factory.h" +#include "base/values/map_value.h" +#include "base/values/map_value_builder.h" namespace cel { @@ -56,6 +61,54 @@ const Handle& MapType::value() const { return static_cast(*this).value(); } +namespace { + +template +absl::StatusOr> NewMapValueBuilderFor( + ValueFactory& value_factory, Handle type) { + switch (type->value()->kind()) { + case Kind::kBool: + return MakeUnique>( + value_factory.memory_manager(), value_factory, std::move(type)); + case Kind::kInt: + return MakeUnique>( + value_factory.memory_manager(), value_factory, std::move(type)); + case Kind::kUint: + return MakeUnique>( + value_factory.memory_manager(), value_factory, std::move(type)); + case Kind::kDouble: + return MakeUnique>( + value_factory.memory_manager(), value_factory, std::move(type)); + case Kind::kDuration: + return MakeUnique>( + value_factory.memory_manager(), value_factory, std::move(type)); + case Kind::kTimestamp: + return MakeUnique>( + value_factory.memory_manager(), value_factory, std::move(type)); + default: + return MakeUnique>( + value_factory.memory_manager(), value_factory, std::move(type)); + } +} + +} // namespace + +absl::StatusOr> MapType::NewValueBuilder( + ValueFactory& value_factory) const { + switch (key()->kind()) { + case Kind::kBool: + return NewMapValueBuilderFor(value_factory, + handle_from_this()); + case Kind::kInt: + return NewMapValueBuilderFor(value_factory, handle_from_this()); + case Kind::kUint: + return NewMapValueBuilderFor(value_factory, + handle_from_this()); + default: + return NewMapValueBuilderFor(value_factory, handle_from_this()); + } +} + namespace base_internal { const Handle& LegacyMapType::key() const { diff --git a/base/types/map_type.h b/base/types/map_type.h index 083f6f82d..14aba4557 100644 --- a/base/types/map_type.h +++ b/base/types/map_type.h @@ -19,9 +19,12 @@ #include #include +#include "absl/base/attributes.h" #include "absl/log/absl_check.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "base/handle.h" #include "base/internal/data.h" #include "base/kind.h" #include "base/memory.h" @@ -32,10 +35,13 @@ namespace cel { class MemoryManager; class TypeFactory; class MapValue; +class ValueFactory; +class MapValueBuilderInterface; // MapType represents a map type. A map is container of key and value pairs // where each key appears at most once. -class MapType : public Type { +class MapType : public Type, + public base_internal::EnableHandleFromThis { public: static constexpr Kind kKind = Kind::kMap; @@ -60,6 +66,10 @@ class MapType : public Type { // Returns the type of the values in the map. const Handle& value() const; + absl::StatusOr> NewValueBuilder( + ValueFactory& value_factory + ABSL_ATTRIBUTE_LIFETIME_BOUND) const ABSL_ATTRIBUTE_LIFETIME_BOUND; + private: friend class Type; friend class MemoryManager; diff --git a/base/types/struct_type.cc b/base/types/struct_type.cc index a06984d00..cca5d956e 100644 --- a/base/types/struct_type.cc +++ b/base/types/struct_type.cc @@ -23,6 +23,7 @@ #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/variant.h" +#include "base/values/struct_value_builder.h" #include "internal/overloaded.h" #include "internal/status_macros.h" @@ -105,6 +106,11 @@ StructType::NewFieldIterator(MemoryManager& memory_manager) const { return CEL_INTERNAL_STRUCT_TYPE_DISPATCH(NewFieldIterator, memory_manager); } +absl::StatusOr> +StructType::NewValueBuilder(ValueFactory& value_factory) const { + return CEL_INTERNAL_STRUCT_TYPE_DISPATCH(NewValueBuilder, value_factory); +} + #undef CEL_INTERNAL_STRUCT_TYPE_DISPATCH struct StructType::FindFieldVisitor final { @@ -182,6 +188,14 @@ LegacyStructType::FindFieldByNumber(TypeManager& type_manager, "Legacy struct type does not support type introspection"); } +absl::StatusOr> +LegacyStructType::NewValueBuilder(ValueFactory& value_factory) const { + return absl::UnimplementedError( + "StructType::NewValueBuilder is unimplemented. Perhaps the value library " + "is not linked into your binary or StructType::NewValueBuilder was not " + "overridden?"); +} + AbstractStructType::AbstractStructType() : StructType(), base_internal::HeapData(kKind) { // Ensure `Type*` and `base_internal::HeapData*` are not thunked. @@ -190,6 +204,14 @@ AbstractStructType::AbstractStructType() reinterpret_cast(static_cast(this))); } +absl::StatusOr> +AbstractStructType::NewValueBuilder(ValueFactory& value_factory) const { + return absl::UnimplementedError( + "StructType::NewValueBuilder is unimplemented. Perhaps the value library " + "is not linked into your binary or StructType::NewValueBuilder was not " + "overridden?"); +} + } // namespace base_internal } // namespace cel diff --git a/base/types/struct_type.h b/base/types/struct_type.h index 3b3a0358b..e02216600 100644 --- a/base/types/struct_type.h +++ b/base/types/struct_type.h @@ -27,6 +27,7 @@ #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/variant.h" +#include "base/handle.h" #include "base/internal/data.h" #include "base/kind.h" #include "base/memory.h" @@ -43,6 +44,8 @@ class MemoryManager; class StructValue; class TypedStructValueFactory; class TypeManager; +class StructValueBuilderInterface; +class ValueFactory; // StructType represents an struct type. An struct is a set of fields // that can be looked up by name and/or number. @@ -133,6 +136,10 @@ class StructType : public Type { absl::StatusOr> NewFieldIterator( MemoryManager& memory_manager) const ABSL_ATTRIBUTE_LIFETIME_BOUND; + absl::StatusOr> NewValueBuilder( + ValueFactory& value_factory + ABSL_ATTRIBUTE_LIFETIME_BOUND) const ABSL_ATTRIBUTE_LIFETIME_BOUND; + protected: static FieldId MakeFieldId(absl::string_view name) { return FieldId(name); } @@ -256,6 +263,10 @@ class LegacyStructType final : public StructType, public InlineData { absl::StatusOr> NewFieldIterator( MemoryManager& memory_manager) const ABSL_ATTRIBUTE_LIFETIME_BOUND; + absl::StatusOr> NewValueBuilder( + ValueFactory& value_factory + ABSL_ATTRIBUTE_LIFETIME_BOUND) const ABSL_ATTRIBUTE_LIFETIME_BOUND; + private: static constexpr uintptr_t kMetadata = kStoredInline | kTrivial | (static_cast(kKind) << kKindShift); @@ -279,7 +290,10 @@ class LegacyStructType final : public StructType, public InlineData { uintptr_t msg_; }; -class AbstractStructType : public StructType, public HeapData { +class AbstractStructType + : public StructType, + public HeapData, + public EnableHandleFromThis { public: static bool Is(const Type& type) { return StructType::Is(type) && @@ -313,6 +327,10 @@ class AbstractStructType : public StructType, public HeapData { virtual absl::StatusOr> NewFieldIterator( MemoryManager& memory_manager) const ABSL_ATTRIBUTE_LIFETIME_BOUND = 0; + virtual absl::StatusOr> + NewValueBuilder(ValueFactory& value_factory ABSL_ATTRIBUTE_LIFETIME_BOUND) + const ABSL_ATTRIBUTE_LIFETIME_BOUND; + protected: AbstractStructType(); diff --git a/base/values/list_value_builder_test.cc b/base/values/list_value_builder_test.cc index 06afbf0e9..7dadc3457 100644 --- a/base/values/list_value_builder_test.cc +++ b/base/values/list_value_builder_test.cc @@ -23,6 +23,9 @@ namespace cel { namespace { +using testing::NotNull; +using testing::WhenDynamicCastTo; + TEST(ListValueBuilder, Unspecialized) { TypeFactory type_factory(MemoryManager::Global()); TypeManager type_manager(type_factory, TypeProvider::Builtin()); @@ -256,6 +259,44 @@ TEST(ListValueBuilder, Timestamp) { EXPECT_EQ(element.As()->value(), absl::UnixEpoch() + absl::Minutes(1)); } +template +void TestListValueBuilderImpl(ValueFactory& value_factory, + const Handle& element) { + ASSERT_OK_AND_ASSIGN(auto type, + value_factory.type_factory().CreateListType(element)); + ASSERT_OK_AND_ASSIGN(auto builder, type->NewValueBuilder(value_factory)); + EXPECT_THAT((&builder.get()), WhenDynamicCastTo(NotNull())); +} + +TEST(ListValueBuilder, Dynamic) { + TypeFactory type_factory(MemoryManager::Global()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); +#ifdef ABSL_INTERNAL_HAS_RTTI + ASSERT_NO_FATAL_FAILURE( + ((TestListValueBuilderImpl>( + value_factory, type_factory.GetBoolType())))); + ASSERT_NO_FATAL_FAILURE( + ((TestListValueBuilderImpl>( + value_factory, type_factory.GetIntType())))); + ASSERT_NO_FATAL_FAILURE( + ((TestListValueBuilderImpl>( + value_factory, type_factory.GetUintType())))); + ASSERT_NO_FATAL_FAILURE( + ((TestListValueBuilderImpl>( + value_factory, type_factory.GetDoubleType())))); + ASSERT_NO_FATAL_FAILURE( + ((TestListValueBuilderImpl>( + value_factory, type_factory.GetDurationType())))); + ASSERT_NO_FATAL_FAILURE( + ((TestListValueBuilderImpl>( + value_factory, type_factory.GetTimestampType())))); + ASSERT_NO_FATAL_FAILURE(((TestListValueBuilderImpl>( + value_factory, type_factory.GetDynType())))); +#else + GTEST_SKIP() << "RTTI unavailable"; +#endif +} } // namespace } // namespace cel diff --git a/base/values/map_value_builder_test.cc b/base/values/map_value_builder_test.cc index 2d077b9e4..3b063ceb2 100644 --- a/base/values/map_value_builder_test.cc +++ b/base/values/map_value_builder_test.cc @@ -25,6 +25,8 @@ namespace { using testing::IsFalse; using testing::IsTrue; +using testing::NotNull; +using testing::WhenDynamicCastTo; using cel::internal::IsOkAndHolds; TEST(MapValueBuilder, UnspecializedUnspecialized) { @@ -768,5 +770,140 @@ TEST(MapValueBuilder, UintTimestamp) { "[0u, 1u, 2u]"); } +template +void TestMapValueBuilderImpl(ValueFactory& value_factory, const Handle& key, + const Handle& value) { + ASSERT_OK_AND_ASSIGN(auto type, + value_factory.type_factory().CreateMapType(key, value)); + ASSERT_OK_AND_ASSIGN(auto builder, type->NewValueBuilder(value_factory)); + EXPECT_THAT((&builder.get()), WhenDynamicCastTo(NotNull())); +} + +TEST(MapValueBuilder, Dynamic) { + TypeFactory type_factory(MemoryManager::Global()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); +#ifdef ABSL_INTERNAL_HAS_RTTI + // (BoolValue, ...) + ASSERT_NO_FATAL_FAILURE( + (TestMapValueBuilderImpl>( + value_factory, type_factory.GetBoolType(), + type_factory.GetBoolType()))); + ASSERT_NO_FATAL_FAILURE( + (TestMapValueBuilderImpl>( + value_factory, type_factory.GetBoolType(), + type_factory.GetIntType()))); + ASSERT_NO_FATAL_FAILURE( + (TestMapValueBuilderImpl>( + value_factory, type_factory.GetBoolType(), + type_factory.GetUintType()))); + ASSERT_NO_FATAL_FAILURE( + (TestMapValueBuilderImpl>( + value_factory, type_factory.GetBoolType(), + type_factory.GetDoubleType()))); + ASSERT_NO_FATAL_FAILURE( + (TestMapValueBuilderImpl>( + value_factory, type_factory.GetBoolType(), + type_factory.GetDurationType()))); + ASSERT_NO_FATAL_FAILURE( + (TestMapValueBuilderImpl>( + value_factory, type_factory.GetBoolType(), + type_factory.GetTimestampType()))); + ASSERT_NO_FATAL_FAILURE( + (TestMapValueBuilderImpl>( + value_factory, type_factory.GetBoolType(), + type_factory.GetDynType()))); + // (IntValue, ...) + ASSERT_NO_FATAL_FAILURE( + (TestMapValueBuilderImpl>( + value_factory, type_factory.GetIntType(), + type_factory.GetBoolType()))); + ASSERT_NO_FATAL_FAILURE( + (TestMapValueBuilderImpl>( + value_factory, type_factory.GetIntType(), + type_factory.GetIntType()))); + ASSERT_NO_FATAL_FAILURE( + (TestMapValueBuilderImpl>( + value_factory, type_factory.GetIntType(), + type_factory.GetUintType()))); + ASSERT_NO_FATAL_FAILURE( + (TestMapValueBuilderImpl>( + value_factory, type_factory.GetIntType(), + type_factory.GetDoubleType()))); + ASSERT_NO_FATAL_FAILURE( + (TestMapValueBuilderImpl>( + value_factory, type_factory.GetIntType(), + type_factory.GetDurationType()))); + ASSERT_NO_FATAL_FAILURE( + (TestMapValueBuilderImpl>( + value_factory, type_factory.GetIntType(), + type_factory.GetTimestampType()))); + ASSERT_NO_FATAL_FAILURE( + (TestMapValueBuilderImpl>( + value_factory, type_factory.GetIntType(), + type_factory.GetDynType()))); + // (UintValue, ...) + ASSERT_NO_FATAL_FAILURE( + (TestMapValueBuilderImpl>( + value_factory, type_factory.GetUintType(), + type_factory.GetBoolType()))); + ASSERT_NO_FATAL_FAILURE( + (TestMapValueBuilderImpl>( + value_factory, type_factory.GetUintType(), + type_factory.GetIntType()))); + ASSERT_NO_FATAL_FAILURE( + (TestMapValueBuilderImpl>( + value_factory, type_factory.GetUintType(), + type_factory.GetUintType()))); + ASSERT_NO_FATAL_FAILURE( + (TestMapValueBuilderImpl>( + value_factory, type_factory.GetUintType(), + type_factory.GetDoubleType()))); + ASSERT_NO_FATAL_FAILURE( + (TestMapValueBuilderImpl>( + value_factory, type_factory.GetUintType(), + type_factory.GetDurationType()))); + ASSERT_NO_FATAL_FAILURE( + (TestMapValueBuilderImpl>( + value_factory, type_factory.GetUintType(), + type_factory.GetTimestampType()))); + ASSERT_NO_FATAL_FAILURE( + (TestMapValueBuilderImpl>( + value_factory, type_factory.GetUintType(), + type_factory.GetDynType()))); + // (StringValue, ...) + ASSERT_NO_FATAL_FAILURE( + (TestMapValueBuilderImpl>( + value_factory, type_factory.GetStringType(), + type_factory.GetBoolType()))); + ASSERT_NO_FATAL_FAILURE( + (TestMapValueBuilderImpl>( + value_factory, type_factory.GetStringType(), + type_factory.GetIntType()))); + ASSERT_NO_FATAL_FAILURE( + (TestMapValueBuilderImpl>( + value_factory, type_factory.GetStringType(), + type_factory.GetUintType()))); + ASSERT_NO_FATAL_FAILURE( + (TestMapValueBuilderImpl>( + value_factory, type_factory.GetStringType(), + type_factory.GetDoubleType()))); + ASSERT_NO_FATAL_FAILURE( + (TestMapValueBuilderImpl>( + value_factory, type_factory.GetStringType(), + type_factory.GetDurationType()))); + ASSERT_NO_FATAL_FAILURE( + (TestMapValueBuilderImpl>( + value_factory, type_factory.GetStringType(), + type_factory.GetTimestampType()))); + ASSERT_NO_FATAL_FAILURE( + (TestMapValueBuilderImpl>( + value_factory, type_factory.GetStringType(), + type_factory.GetDynType()))); +#else + GTEST_SKIP() << "RTTI unavailable"; +#endif +} + } // namespace } // namespace cel diff --git a/base/values/struct_value.h b/base/values/struct_value.h index 1e42f68c0..9590ff9e0 100644 --- a/base/values/struct_value.h +++ b/base/values/struct_value.h @@ -43,6 +43,8 @@ struct LegacyStructValueAccess; } class ValueFactory; +class StructValueBuilder; +class StructValueBuilderInterface; // StructValue represents an instance of cel::StructType. class StructValue : public Value { diff --git a/base/values/struct_value_builder.h b/base/values/struct_value_builder.h new file mode 100644 index 000000000..dccd4709c --- /dev/null +++ b/base/values/struct_value_builder.h @@ -0,0 +1,40 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_VALUES_STRUCT_VALUE_BUILDER_H_ +#define THIRD_PARTY_CEL_CPP_BASE_VALUES_STRUCT_VALUE_BUILDER_H_ + +#include "absl/strings/string_view.h" +#include "base/values/struct_value.h" + +namespace cel { + +class StructValueBuilderInterface { + public: + virtual ~StructValueBuilderInterface() = default; + + absl::Status SetField(StructValue::FieldId id, Handle value); + + virtual absl::Status SetFieldByName(absl::string_view name, + Handle value) = 0; + + virtual absl::Status SetFieldByNumber(int64_t number, + Handle value) = 0; + + virtual absl::StatusOr> Build() && = 0; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_VALUES_STRUCT_VALUE_BUILDER_H_ diff --git a/extensions/protobuf/BUILD b/extensions/protobuf/BUILD index cb68a61e5..abf4a4f0f 100644 --- a/extensions/protobuf/BUILD +++ b/extensions/protobuf/BUILD @@ -119,9 +119,9 @@ cc_test( ], deps = [ ":type", + "//base:data", "//base:kind", "//base:memory", - "//base:type", "//base/internal:memory_manager_testing", "//base/testing:type_matchers", "//extensions/protobuf/internal:testing", diff --git a/extensions/protobuf/struct_type_test.cc b/extensions/protobuf/struct_type_test.cc index 4499b656d..489dde991 100644 --- a/extensions/protobuf/struct_type_test.cc +++ b/extensions/protobuf/struct_type_test.cc @@ -26,6 +26,8 @@ #include "base/type_manager.h" #include "base/types/list_type.h" #include "base/types/map_type.h" +#include "base/value_factory.h" +#include "base/values/struct_value_builder.h" #include "extensions/protobuf/internal/testing.h" #include "extensions/protobuf/type.h" #include "extensions/protobuf/type_provider.h" @@ -263,6 +265,17 @@ TEST_P(ProtoStructTypeTest, NewFieldIteratorTypes) { // itself, which would not be useful. } +TEST_P(ProtoStructTypeTest, NewValueBuilderUnimplemented) { + TypeFactory type_factory(memory_manager()); + ProtoTypeProvider type_provider; + TypeManager type_manager(type_factory, type_provider); + ValueFactory value_factory(type_manager); + ASSERT_OK_AND_ASSIGN(auto type, + ProtoType::Resolve(type_manager)); + EXPECT_THAT(type->NewValueBuilder(value_factory), + StatusIs(absl::StatusCode::kUnimplemented)); +} + INSTANTIATE_TEST_SUITE_P(ProtoStructTypeTest, ProtoStructTypeTest, cel::base_internal::MemoryManagerTestModeAll(), cel::base_internal::MemoryManagerTestModeTupleName); From bb67412c11952bf8f3c9dafc39befd7fdfbf1373 Mon Sep 17 00:00:00 2001 From: Justin King Date: Mon, 22 May 2023 11:29:26 -0700 Subject: [PATCH 299/303] Add `TypeKind` and `ValueKind` PiperOrigin-RevId: 534130090 --- base/BUILD | 2 + base/kind.cc | 9 +++ base/kind.h | 156 +++++++++++++++++++++++++++++++++++++++++++++- base/kind_test.cc | 50 +++++++++++++++ 4 files changed, 216 insertions(+), 1 deletion(-) diff --git a/base/BUILD b/base/BUILD index 6c5211d69..8845804f2 100644 --- a/base/BUILD +++ b/base/BUILD @@ -65,7 +65,9 @@ cc_library( srcs = ["kind.cc"], hdrs = ["kind.h"], deps = [ + "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/strings", ], ) diff --git a/base/kind.cc b/base/kind.cc index fc37049ba..122295aef 100644 --- a/base/kind.cc +++ b/base/kind.cc @@ -16,6 +16,15 @@ namespace cel { +bool KindIsTypeKind(Kind kind) { + // Currently all Kind are valid TypeKind. + return true; +} + +bool KindIsValueKind(Kind kind) { + return kind != Kind::kWrapper && kind != Kind::kDyn && kind != Kind::kAny; +} + absl::string_view KindToString(Kind kind) { switch (kind) { case Kind::kNullType: diff --git a/base/kind.h b/base/kind.h index d870d42ac..9096f3577 100644 --- a/base/kind.h +++ b/base/kind.h @@ -15,14 +15,18 @@ #ifndef THIRD_PARTY_CEL_CPP_BASE_KIND_H_ #define THIRD_PARTY_CEL_CPP_BASE_KIND_H_ +#include + #include "absl/base/attributes.h" +#include "absl/base/casts.h" +#include "absl/log/absl_check.h" #include "absl/strings/string_view.h" namespace cel { enum class Kind /* : uint8_t */ { // Must match legacy CelValue::Type. - kNullType = 0, + kNull = 0, kBool, kInt, kUint, @@ -46,6 +50,7 @@ enum class Kind /* : uint8_t */ { kOpaque, // Legacy aliases, deprecated do not use. + kNullType = kNull, kInt64 = kInt, kUint64 = kUint, kMessage = kStruct, @@ -59,6 +64,155 @@ enum class Kind /* : uint8_t */ { ABSL_ATTRIBUTE_PURE_FUNCTION absl::string_view KindToString(Kind kind); +// `TypeKind` is a subset of `Kind`, representing all valid `Kind` for `Type`. +// All `TypeKind` are valid `Kind`, but it is not guaranteed that all `Kind` are +// valid `TypeKind`. +enum class TypeKind : std::underlying_type_t { + kNull = static_cast(Kind::kNull), + kBool = static_cast(Kind::kBool), + kInt = static_cast(Kind::kInt), + kUint = static_cast(Kind::kUint), + kDouble = static_cast(Kind::kDouble), + kString = static_cast(Kind::kString), + kBytes = static_cast(Kind::kBytes), + kStruct = static_cast(Kind::kStruct), + kDuration = static_cast(Kind::kDuration), + kTimestamp = static_cast(Kind::kTimestamp), + kList = static_cast(Kind::kList), + kMap = static_cast(Kind::kMap), + kUnknown = static_cast(Kind::kUnknown), + kType = static_cast(Kind::kType), + kError = static_cast(Kind::kError), + kAny = static_cast(Kind::kAny), + kEnum = static_cast(Kind::kEnum), + kDyn = static_cast(Kind::kDyn), + kWrapper = static_cast(Kind::kWrapper), + kOpaque = static_cast(Kind::kOpaque), + + // Legacy aliases, deprecated do not use. + kNullType = kNull, + kInt64 = kInt, + kUint64 = kUint, + kMessage = kStruct, + kUnknownSet = kUnknown, + kCelType = kType, + + // INTERNAL: Do not exceed 63. Implementation details rely on the fact that + // we can store `Kind` using 6 bits. + kNotForUseWithExhaustiveSwitchStatements = + static_cast(Kind::kNotForUseWithExhaustiveSwitchStatements), +}; + +inline Kind TypeKindToKind(TypeKind kind) { return absl::bit_cast(kind); } + +ABSL_ATTRIBUTE_PURE_FUNCTION bool KindIsTypeKind(Kind kind); + +inline bool operator==(Kind lhs, TypeKind rhs) { + return lhs == TypeKindToKind(rhs); +} + +inline bool operator==(TypeKind lhs, Kind rhs) { + return TypeKindToKind(lhs) == rhs; +} + +inline bool operator!=(Kind lhs, TypeKind rhs) { return !operator==(lhs, rhs); } + +inline bool operator!=(TypeKind lhs, Kind rhs) { return !operator==(lhs, rhs); } + +// `ValueKind` is a subset of `Kind`, representing all valid `Kind` for `Value`. +// All `ValueKind` are valid `Kind`, but it is not guaranteed that all `Kind` +// are valid `ValueKind`. +enum class ValueKind : std::underlying_type_t { + kNull = static_cast(Kind::kNull), + kBool = static_cast(Kind::kBool), + kInt = static_cast(Kind::kInt), + kUint = static_cast(Kind::kUint), + kDouble = static_cast(Kind::kDouble), + kString = static_cast(Kind::kString), + kBytes = static_cast(Kind::kBytes), + kStruct = static_cast(Kind::kStruct), + kDuration = static_cast(Kind::kDuration), + kTimestamp = static_cast(Kind::kTimestamp), + kList = static_cast(Kind::kList), + kMap = static_cast(Kind::kMap), + kUnknown = static_cast(Kind::kUnknown), + kType = static_cast(Kind::kType), + kError = static_cast(Kind::kError), + kEnum = static_cast(Kind::kEnum), + kOpaque = static_cast(Kind::kOpaque), + + // Legacy aliases, deprecated do not use. + kNullType = kNull, + kInt64 = kInt, + kUint64 = kUint, + kMessage = kStruct, + kUnknownSet = kUnknown, + kCelType = kType, + + // INTERNAL: Do not exceed 63. Implementation details rely on the fact that + // we can store `Kind` using 6 bits. + kNotForUseWithExhaustiveSwitchStatements = + static_cast(Kind::kNotForUseWithExhaustiveSwitchStatements), +}; + +inline Kind ValueKindToKind(ValueKind kind) { + return absl::bit_cast(kind); +} + +ABSL_ATTRIBUTE_PURE_FUNCTION bool KindIsValueKind(Kind kind); + +inline bool operator==(Kind lhs, ValueKind rhs) { + return lhs == ValueKindToKind(rhs); +} + +inline bool operator==(ValueKind lhs, Kind rhs) { + return ValueKindToKind(lhs) == rhs; +} + +inline bool operator==(TypeKind lhs, ValueKind rhs) { + return TypeKindToKind(lhs) == ValueKindToKind(rhs); +} + +inline bool operator==(ValueKind lhs, TypeKind rhs) { + return ValueKindToKind(lhs) == TypeKindToKind(rhs); +} + +inline bool operator!=(Kind lhs, ValueKind rhs) { + return !operator==(lhs, rhs); +} + +inline bool operator!=(ValueKind lhs, Kind rhs) { + return !operator==(lhs, rhs); +} + +inline bool operator!=(TypeKind lhs, ValueKind rhs) { + return !operator==(lhs, rhs); +} + +inline bool operator!=(ValueKind lhs, TypeKind rhs) { + return !operator==(lhs, rhs); +} + +inline absl::string_view TypeKindToString(TypeKind kind) { + // All TypeKind are valid Kind. + return KindToString(TypeKindToKind(kind)); +} + +inline absl::string_view ValueKindToString(ValueKind kind) { + // All ValueKind are valid Kind. + return KindToString(ValueKindToKind(kind)); +} + +inline TypeKind KindToTypeKind(Kind kind) { + ABSL_DCHECK(KindIsTypeKind(kind)) << KindToString(kind); + return absl::bit_cast(kind); +} + +inline ValueKind KindToValueKind(Kind kind) { + ABSL_DCHECK(KindIsValueKind(kind)) << KindToString(kind); + return absl::bit_cast(kind); +} + } // namespace cel #endif // THIRD_PARTY_CEL_CPP_BASE_KIND_H_ diff --git a/base/kind_test.cc b/base/kind_test.cc index f82ed0c9d..2fde907d5 100644 --- a/base/kind_test.cc +++ b/base/kind_test.cc @@ -46,5 +46,55 @@ TEST(Kind, ToString) { "*error*"); } +TEST(Kind, TypeKindRoundtrip) { + EXPECT_EQ(TypeKindToKind(KindToTypeKind(Kind::kBool)), Kind::kBool); +} + +TEST(Kind, ValueKindRoundtrip) { + EXPECT_EQ(ValueKindToKind(KindToValueKind(Kind::kBool)), Kind::kBool); +} + +TEST(Kind, IsTypeKind) { + EXPECT_TRUE(KindIsTypeKind(Kind::kBool)); + EXPECT_TRUE(KindIsTypeKind(Kind::kAny)); + EXPECT_TRUE(KindIsTypeKind(Kind::kDyn)); + EXPECT_TRUE(KindIsTypeKind(Kind::kWrapper)); +} + +TEST(Kind, IsValueKind) { + EXPECT_TRUE(KindIsValueKind(Kind::kBool)); + EXPECT_FALSE(KindIsValueKind(Kind::kAny)); + EXPECT_FALSE(KindIsValueKind(Kind::kDyn)); + EXPECT_FALSE(KindIsValueKind(Kind::kWrapper)); +} + +TEST(Kind, Equality) { + EXPECT_EQ(Kind::kBool, TypeKind::kBool); + EXPECT_EQ(TypeKind::kBool, Kind::kBool); + + EXPECT_EQ(Kind::kBool, ValueKind::kBool); + EXPECT_EQ(ValueKind::kBool, Kind::kBool); + + EXPECT_EQ(TypeKind::kBool, ValueKind::kBool); + EXPECT_EQ(ValueKind::kBool, TypeKind::kBool); + + EXPECT_NE(Kind::kBool, TypeKind::kInt); + EXPECT_NE(TypeKind::kInt, Kind::kBool); + + EXPECT_NE(Kind::kBool, ValueKind::kInt); + EXPECT_NE(ValueKind::kInt, Kind::kBool); + + EXPECT_NE(TypeKind::kBool, ValueKind::kInt); + EXPECT_NE(ValueKind::kInt, TypeKind::kBool); +} + +TEST(TypeKind, ToString) { + EXPECT_EQ(TypeKindToString(TypeKind::kBool), KindToString(Kind::kBool)); +} + +TEST(ValueKind, ToString) { + EXPECT_EQ(ValueKindToString(ValueKind::kBool), KindToString(Kind::kBool)); +} + } // namespace } // namespace cel From 7d3b4de1f52588c311fb5bde453063e6180ff492 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Mon, 22 May 2023 12:20:19 -0700 Subject: [PATCH 300/303] Update bit math constants for tagged messages for interop. Updated names should help readability since the tagging scheme is overloaded for both lite proto support and for handle migration. PiperOrigin-RevId: 534145645 --- base/internal/message_wrapper.h | 5 ++++- base/values/struct_value.cc | 4 +++- eval/internal/BUILD | 8 ++++---- eval/internal/interop.cc | 13 +++++++++--- eval/internal/interop_test.cc | 36 +++++++++++++++++++++++++++++++++ eval/public/message_wrapper.h | 26 ++++++++++++++++-------- 6 files changed, 75 insertions(+), 17 deletions(-) diff --git a/base/internal/message_wrapper.h b/base/internal/message_wrapper.h index 72491ede9..dc18f90c3 100644 --- a/base/internal/message_wrapper.h +++ b/base/internal/message_wrapper.h @@ -19,8 +19,11 @@ namespace cel::base_internal { -inline constexpr uintptr_t kMessageWrapperTagMask = 1 << 0; +inline constexpr uintptr_t kMessageWrapperTagMask = 0b1; inline constexpr uintptr_t kMessageWrapperPtrMask = ~kMessageWrapperTagMask; +inline constexpr uintptr_t kMessageWrapperTagSize = 1; +inline constexpr uintptr_t kMessageWrapperTagTypeInfoValue = 0b0; +inline constexpr uintptr_t kMessageWrapperTagMessageValue = 0b1; } // namespace cel::base_internal diff --git a/base/values/struct_value.cc b/base/values/struct_value.cc index 7c9a5dbfb..fccea1deb 100644 --- a/base/values/struct_value.cc +++ b/base/values/struct_value.cc @@ -179,11 +179,13 @@ class LegacyStructValueFieldIterator final : public StructValue::FieldIterator { }; Handle LegacyStructValue::type() const { - if ((msg_ & kMessageWrapperTagMask) == kMessageWrapperTagMask) { + uintptr_t tag = msg_ & kMessageWrapperTagMask; + if (tag == kMessageWrapperTagMessageValue) { // google::protobuf::Message return HandleFactory::Make(msg_); } // LegacyTypeInfoApis + ABSL_ASSERT(tag == kMessageWrapperTagTypeInfoValue); return HandleFactory::Make(type_info_); } diff --git a/eval/internal/BUILD b/eval/internal/BUILD index c91efc0ab..2ffebb574 100644 --- a/eval/internal/BUILD +++ b/eval/internal/BUILD @@ -22,8 +22,7 @@ cc_library( hdrs = ["interop.h"], deps = [ ":errors", - "//base:type", - "//base:value", + "//base:data", "//base/internal:message_wrapper", "//eval/public:cel_options", "//eval/public:cel_value", @@ -51,14 +50,15 @@ cc_test( deps = [ ":errors", ":interop", + "//base:data", "//base:memory", - "//base:type", - "//base:value", "//eval/public:cel_value", + "//eval/public:message_wrapper", "//eval/public:unknown_set", "//eval/public/containers:container_backed_list_impl", "//eval/public/containers:container_backed_map_impl", "//eval/public/structs:cel_proto_wrapper", + "//eval/public/structs:trivial_legacy_type_info", "//extensions/protobuf:memory_manager", "//extensions/protobuf:type", "//extensions/protobuf:value", diff --git a/eval/internal/interop.cc b/eval/internal/interop.cc index 4bc2a55fa..a2163f778 100644 --- a/eval/internal/interop.cc +++ b/eval/internal/interop.cc @@ -14,6 +14,7 @@ #include "eval/internal/interop.h" +#include #include #include #include @@ -43,6 +44,7 @@ #include "eval/public/unknown_set.h" #include "extensions/protobuf/memory_manager.h" #include "internal/status_macros.h" +#include "google/protobuf/message.h" namespace cel::interop_internal { @@ -667,12 +669,17 @@ using ::google::api::expr::runtime::MessageWrapper; } // namespace absl::string_view MessageTypeName(uintptr_t msg) { - if ((msg & kMessageWrapperTagMask) != kMessageWrapperTagMask) { + uintptr_t tag = (msg & kMessageWrapperTagMask); + uintptr_t ptr = (msg & kMessageWrapperPtrMask); + + if (tag == kMessageWrapperTagTypeInfoValue) { // For google::protobuf::MessageLite, this is actually LegacyTypeInfoApis. - return reinterpret_cast(msg)->GetTypename( + return reinterpret_cast(ptr)->GetTypename( MessageWrapper()); } - return reinterpret_cast(msg & kMessageWrapperPtrMask) + ABSL_ASSERT(tag == kMessageWrapperTagMessageValue); + + return reinterpret_cast(ptr) ->GetDescriptor() ->full_name(); } diff --git a/eval/internal/interop_test.cc b/eval/internal/interop_test.cc index 50c29bda5..24d4e8b88 100644 --- a/eval/internal/interop_test.cc +++ b/eval/internal/interop_test.cc @@ -21,6 +21,7 @@ #include #include "google/protobuf/api.pb.h" +#include "google/protobuf/empty.pb.h" #include "google/protobuf/arena.h" #include "absl/status/status.h" #include "absl/strings/escaping.h" @@ -36,7 +37,9 @@ #include "eval/public/cel_value.h" #include "eval/public/containers/container_backed_list_impl.h" #include "eval/public/containers/container_backed_map_impl.h" +#include "eval/public/message_wrapper.h" #include "eval/public/structs/cel_proto_wrapper.h" +#include "eval/public/structs/trivial_legacy_type_info.h" #include "eval/public/unknown_set.h" #include "extensions/protobuf/memory_manager.h" #include "extensions/protobuf/type_provider.h" @@ -49,6 +52,7 @@ namespace { using ::google::api::expr::runtime::CelProtoWrapper; using ::google::api::expr::runtime::CelValue; using ::google::api::expr::runtime::ContainerBackedListImpl; +using ::google::api::expr::runtime::MessageWrapper; using ::google::api::expr::runtime::UnknownSet; using testing::Eq; using testing::HasSubstr; @@ -763,6 +767,38 @@ TEST(ValueInterop, StructFromLegacy) { value_wrapper.legacy_type_info()); } +TEST(ValueInterop, StructFromLegacyMessageLite) { + google::protobuf::Arena arena; + extensions::ProtoMemoryManager memory_manager(&arena); + TypeFactory type_factory(memory_manager); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); + google::protobuf::Empty opaque; + MessageWrapper wrapper( + static_cast(&opaque), + google::api::expr::runtime::TrivialTypeInfo::GetInstance()); + CelValue legacy_value = CelValue::CreateMessageWrapper(wrapper); + ASSERT_OK_AND_ASSIGN(auto value, FromLegacyValue(&arena, legacy_value)); + EXPECT_EQ(value->kind(), Kind::kStruct); + EXPECT_EQ(value->type()->kind(), Kind::kStruct); + EXPECT_EQ(value->type()->name(), "opaque type"); + EXPECT_THAT( + value.As()->HasFieldByName( + StructValue::HasFieldContext(type_manager), "name"), + StatusIs(absl::StatusCode::kNotFound, HasSubstr("no_such_field"))); + EXPECT_THAT(value.As()->HasFieldByNumber( + StructValue::HasFieldContext(type_manager), 1), + StatusIs(absl::StatusCode::kUnimplemented)); + EXPECT_EQ(value.As()->DebugString(), "opaque type"); + auto value_wrapper = LegacyStructValueAccess::ToMessageWrapper( + *value.As()); + auto legacy_value_wrapper = legacy_value.MessageWrapperOrDie(); + EXPECT_EQ(legacy_value_wrapper.HasFullProto(), value_wrapper.HasFullProto()); + EXPECT_EQ(legacy_value_wrapper.message_ptr(), value_wrapper.message_ptr()); + EXPECT_EQ(legacy_value_wrapper.legacy_type_info(), + value_wrapper.legacy_type_info()); +} + TEST(ValueInterop, LegacyStructRoundtrip) { google::protobuf::Arena arena; extensions::ProtoMemoryManager memory_manager(&arena); diff --git a/eval/public/message_wrapper.h b/eval/public/message_wrapper.h index 95af72efa..ffa8648bc 100644 --- a/eval/public/message_wrapper.h +++ b/eval/public/message_wrapper.h @@ -45,18 +45,22 @@ class MessageWrapper { public: explicit Builder(google::protobuf::MessageLite* message) : message_ptr_(reinterpret_cast(message)) { - ABSL_ASSERT(absl::countr_zero(reinterpret_cast(message)) >= 1); + ABSL_ASSERT(absl::countr_zero(reinterpret_cast(message)) >= + kTagSize); } explicit Builder(google::protobuf::Message* message) - : message_ptr_(reinterpret_cast(message) | kTagMask) { - ABSL_ASSERT(absl::countr_zero(reinterpret_cast(message)) >= 1); + : message_ptr_(reinterpret_cast(message) | kMessageTag) { + ABSL_ASSERT(absl::countr_zero(reinterpret_cast(message)) >= + kTagSize); } google::protobuf::MessageLite* message_ptr() const { return reinterpret_cast(message_ptr_ & kPtrMask); } - bool HasFullProto() const { return (message_ptr_ & kTagMask) == kTagMask; } + bool HasFullProto() const { + return (message_ptr_ & kTagMask) == kMessageTag; + } MessageWrapper Build(const LegacyTypeInfoApis* type_info) { return MessageWrapper(message_ptr_, type_info); @@ -78,19 +82,21 @@ class MessageWrapper { const LegacyTypeInfoApis* legacy_type_info) : message_ptr_(reinterpret_cast(message)), legacy_type_info_(legacy_type_info) { - ABSL_ASSERT(absl::countr_zero(reinterpret_cast(message)) >= 1); + ABSL_ASSERT(absl::countr_zero(reinterpret_cast(message)) >= + kTagSize); } MessageWrapper(const google::protobuf::Message* message, const LegacyTypeInfoApis* legacy_type_info) - : message_ptr_(reinterpret_cast(message) | kTagMask), + : message_ptr_(reinterpret_cast(message) | kMessageTag), legacy_type_info_(legacy_type_info) { - ABSL_ASSERT(absl::countr_zero(reinterpret_cast(message)) >= 1); + ABSL_ASSERT(absl::countr_zero(reinterpret_cast(message)) >= + kTagSize); } // If true, the message is using the full proto runtime and downcasting to // message should be safe. - bool HasFullProto() const { return (message_ptr_ & kTagMask) == kTagMask; } + bool HasFullProto() const { return (message_ptr_ & kTagMask) == kMessageTag; } // Returns the underlying message. // @@ -114,10 +120,14 @@ class MessageWrapper { Builder ToBuilder() { return Builder(message_ptr_); } + static constexpr uintptr_t kTagSize = + ::cel::base_internal::kMessageWrapperTagSize; static constexpr uintptr_t kTagMask = ::cel::base_internal::kMessageWrapperTagMask; static constexpr uintptr_t kPtrMask = ::cel::base_internal::kMessageWrapperPtrMask; + static constexpr uintptr_t kMessageTag = + ::cel::base_internal::kMessageWrapperTagMessageValue; uintptr_t message_ptr_; const LegacyTypeInfoApis* legacy_type_info_; }; From 51d9fd45eadb50d69d095d77bfdc5f3785e7d728 Mon Sep 17 00:00:00 2001 From: Justin King Date: Mon, 22 May 2023 13:40:14 -0700 Subject: [PATCH 301/303] Migrate `Type::kind()` to `TypeKind` PiperOrigin-RevId: 534168362 --- base/BUILD | 2 - base/internal/data.h | 2 + base/internal/type.h | 4 +- base/kind.cc | 9 -- base/kind.h | 62 ++++++---- base/type.cc | 186 ++++++++++++++-------------- base/type.h | 52 ++++---- base/type_test.cc | 52 ++++---- base/types/any_type.h | 4 +- base/types/bool_type.h | 4 +- base/types/bytes_type.h | 4 +- base/types/double_type.h | 4 +- base/types/duration_type.h | 5 +- base/types/dyn_type.h | 4 +- base/types/enum_type.h | 4 +- base/types/error_type.h | 4 +- base/types/int_type.h | 4 +- base/types/list_type.cc | 12 +- base/types/list_type.h | 6 +- base/types/map_type.cc | 18 +-- base/types/map_type.h | 6 +- base/types/null_type.h | 4 +- base/types/opaque_type.h | 4 +- base/types/string_type.h | 4 +- base/types/struct_type.h | 4 +- base/types/timestamp_type.h | 5 +- base/types/type_type.h | 4 +- base/types/uint_type.h | 4 +- base/types/unknown_type.h | 4 +- base/types/wrapper_type.h | 18 +-- base/value.h | 2 +- base/values/bytes_value.h | 2 +- base/values/enum_value.h | 2 +- base/values/error_value.h | 2 +- base/values/list_value.h | 2 +- base/values/map_value.h | 2 +- base/values/string_value.h | 2 +- base/values/type_value.h | 2 +- base/values/unknown_value.h | 2 +- extensions/protobuf/BUILD | 3 +- extensions/protobuf/struct_value.cc | 94 +++++++------- extensions/protobuf/value.cc | 32 ++--- 42 files changed, 325 insertions(+), 321 deletions(-) diff --git a/base/BUILD b/base/BUILD index 8845804f2..6c5211d69 100644 --- a/base/BUILD +++ b/base/BUILD @@ -65,9 +65,7 @@ cc_library( srcs = ["kind.cc"], hdrs = ["kind.h"], deps = [ - "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/strings", ], ) diff --git a/base/internal/data.h b/base/internal/data.h index 86d1d9621..d796e0717 100644 --- a/base/internal/data.h +++ b/base/internal/data.h @@ -184,6 +184,8 @@ class HeapData /* : public Data */ { : metadata_and_reference_count_(static_cast(kind) << kKindShift) {} + explicit HeapData(TypeKind kind) : HeapData(TypeKindToKind(kind)) {} + private: // Called by Arena-based memory managers to determine whether we actually need // our destructor called. Subclasses should override this if they want their diff --git a/base/internal/type.h b/base/internal/type.h index 567c28d2a..41603a170 100644 --- a/base/internal/type.h +++ b/base/internal/type.h @@ -45,7 +45,7 @@ class LegacyMapType; class ModernMapType; struct FieldIdFactory; -template +template class SimpleType; template class SimpleValue; @@ -108,7 +108,7 @@ struct TypeTraits; static_assert(!::std::is_abstract_v, "this must not be abstract"); \ \ bool derived::Is(const ::cel::Type& type) { \ - return type.kind() == ::cel::Kind::k##base && \ + return type.kind() == ::cel::TypeKind::k##base && \ ::cel::base_internal::Get##base##TypeTypeId( \ static_cast(type)) == \ ::cel::internal::TypeId(); \ diff --git a/base/kind.cc b/base/kind.cc index 122295aef..fc37049ba 100644 --- a/base/kind.cc +++ b/base/kind.cc @@ -16,15 +16,6 @@ namespace cel { -bool KindIsTypeKind(Kind kind) { - // Currently all Kind are valid TypeKind. - return true; -} - -bool KindIsValueKind(Kind kind) { - return kind != Kind::kWrapper && kind != Kind::kDyn && kind != Kind::kAny; -} - absl::string_view KindToString(Kind kind) { switch (kind) { case Kind::kNullType: diff --git a/base/kind.h b/base/kind.h index 9096f3577..085a7ee32 100644 --- a/base/kind.h +++ b/base/kind.h @@ -18,8 +18,7 @@ #include #include "absl/base/attributes.h" -#include "absl/base/casts.h" -#include "absl/log/absl_check.h" +#include "absl/base/macros.h" #include "absl/strings/string_view.h" namespace cel { @@ -103,21 +102,30 @@ enum class TypeKind : std::underlying_type_t { static_cast(Kind::kNotForUseWithExhaustiveSwitchStatements), }; -inline Kind TypeKindToKind(TypeKind kind) { return absl::bit_cast(kind); } +constexpr Kind TypeKindToKind(TypeKind kind) { + return static_cast(static_cast>(kind)); +} -ABSL_ATTRIBUTE_PURE_FUNCTION bool KindIsTypeKind(Kind kind); +constexpr bool KindIsTypeKind(Kind kind ABSL_ATTRIBUTE_UNUSED) { + // Currently all Kind are valid TypeKind. + return true; +} -inline bool operator==(Kind lhs, TypeKind rhs) { +constexpr bool operator==(Kind lhs, TypeKind rhs) { return lhs == TypeKindToKind(rhs); } -inline bool operator==(TypeKind lhs, Kind rhs) { +constexpr bool operator==(TypeKind lhs, Kind rhs) { return TypeKindToKind(lhs) == rhs; } -inline bool operator!=(Kind lhs, TypeKind rhs) { return !operator==(lhs, rhs); } +constexpr bool operator!=(Kind lhs, TypeKind rhs) { + return !operator==(lhs, rhs); +} -inline bool operator!=(TypeKind lhs, Kind rhs) { return !operator==(lhs, rhs); } +constexpr bool operator!=(TypeKind lhs, Kind rhs) { + return !operator==(lhs, rhs); +} // `ValueKind` is a subset of `Kind`, representing all valid `Kind` for `Value`. // All `ValueKind` are valid `Kind`, but it is not guaranteed that all `Kind` @@ -155,41 +163,44 @@ enum class ValueKind : std::underlying_type_t { static_cast(Kind::kNotForUseWithExhaustiveSwitchStatements), }; -inline Kind ValueKindToKind(ValueKind kind) { - return absl::bit_cast(kind); +constexpr Kind ValueKindToKind(ValueKind kind) { + return static_cast( + static_cast>(kind)); } -ABSL_ATTRIBUTE_PURE_FUNCTION bool KindIsValueKind(Kind kind); +constexpr bool KindIsValueKind(Kind kind) { + return kind != Kind::kWrapper && kind != Kind::kDyn && kind != Kind::kAny; +} -inline bool operator==(Kind lhs, ValueKind rhs) { +constexpr bool operator==(Kind lhs, ValueKind rhs) { return lhs == ValueKindToKind(rhs); } -inline bool operator==(ValueKind lhs, Kind rhs) { +constexpr bool operator==(ValueKind lhs, Kind rhs) { return ValueKindToKind(lhs) == rhs; } -inline bool operator==(TypeKind lhs, ValueKind rhs) { +constexpr bool operator==(TypeKind lhs, ValueKind rhs) { return TypeKindToKind(lhs) == ValueKindToKind(rhs); } -inline bool operator==(ValueKind lhs, TypeKind rhs) { +constexpr bool operator==(ValueKind lhs, TypeKind rhs) { return ValueKindToKind(lhs) == TypeKindToKind(rhs); } -inline bool operator!=(Kind lhs, ValueKind rhs) { +constexpr bool operator!=(Kind lhs, ValueKind rhs) { return !operator==(lhs, rhs); } -inline bool operator!=(ValueKind lhs, Kind rhs) { +constexpr bool operator!=(ValueKind lhs, Kind rhs) { return !operator==(lhs, rhs); } -inline bool operator!=(TypeKind lhs, ValueKind rhs) { +constexpr bool operator!=(TypeKind lhs, ValueKind rhs) { return !operator==(lhs, rhs); } -inline bool operator!=(ValueKind lhs, TypeKind rhs) { +constexpr bool operator!=(ValueKind lhs, TypeKind rhs) { return !operator==(lhs, rhs); } @@ -203,14 +214,15 @@ inline absl::string_view ValueKindToString(ValueKind kind) { return KindToString(ValueKindToKind(kind)); } -inline TypeKind KindToTypeKind(Kind kind) { - ABSL_DCHECK(KindIsTypeKind(kind)) << KindToString(kind); - return absl::bit_cast(kind); +constexpr TypeKind KindToTypeKind(Kind kind) { + ABSL_ASSERT(KindIsTypeKind(kind)); + return static_cast(static_cast>(kind)); } -inline ValueKind KindToValueKind(Kind kind) { - ABSL_DCHECK(KindIsValueKind(kind)) << KindToString(kind); - return absl::bit_cast(kind); +constexpr ValueKind KindToValueKind(Kind kind) { + ABSL_ASSERT(KindIsValueKind(kind)); + return static_cast( + static_cast>(kind)); } } // namespace cel diff --git a/base/type.cc b/base/type.cc index 281318422..f7f96a07b 100644 --- a/base/type.cc +++ b/base/type.cc @@ -48,45 +48,45 @@ CEL_INTERNAL_TYPE_IMPL(Type); absl::string_view Type::name() const { switch (kind()) { - case Kind::kNullType: + case TypeKind::kNullType: return static_cast(this)->name(); - case Kind::kError: + case TypeKind::kError: return static_cast(this)->name(); - case Kind::kDyn: + case TypeKind::kDyn: return static_cast(this)->name(); - case Kind::kAny: + case TypeKind::kAny: return static_cast(this)->name(); - case Kind::kType: + case TypeKind::kType: return static_cast(this)->name(); - case Kind::kBool: + case TypeKind::kBool: return static_cast(this)->name(); - case Kind::kInt: + case TypeKind::kInt: return static_cast(this)->name(); - case Kind::kUint: + case TypeKind::kUint: return static_cast(this)->name(); - case Kind::kDouble: + case TypeKind::kDouble: return static_cast(this)->name(); - case Kind::kString: + case TypeKind::kString: return static_cast(this)->name(); - case Kind::kBytes: + case TypeKind::kBytes: return static_cast(this)->name(); - case Kind::kEnum: + case TypeKind::kEnum: return static_cast(this)->name(); - case Kind::kDuration: + case TypeKind::kDuration: return static_cast(this)->name(); - case Kind::kTimestamp: + case TypeKind::kTimestamp: return static_cast(this)->name(); - case Kind::kList: + case TypeKind::kList: return static_cast(this)->name(); - case Kind::kMap: + case TypeKind::kMap: return static_cast(this)->name(); - case Kind::kStruct: + case TypeKind::kStruct: return static_cast(this)->name(); - case Kind::kUnknown: + case TypeKind::kUnknown: return static_cast(this)->name(); - case Kind::kWrapper: + case TypeKind::kWrapper: return static_cast(this)->name(); - case Kind::kOpaque: + case TypeKind::kOpaque: return static_cast(this)->name(); default: return "*unreachable*"; @@ -95,13 +95,13 @@ absl::string_view Type::name() const { absl::Span Type::aliases() const { switch (kind()) { - case Kind::kDyn: + case TypeKind::kDyn: return static_cast(this)->aliases(); - case Kind::kList: + case TypeKind::kList: return static_cast(this)->aliases(); - case Kind::kMap: + case TypeKind::kMap: return static_cast(this)->aliases(); - case Kind::kWrapper: + case TypeKind::kWrapper: return static_cast(this)->aliases(); default: // Everything else does not support aliases. @@ -111,102 +111,102 @@ absl::Span Type::aliases() const { std::string Type::DebugString() const { switch (kind()) { - case Kind::kNullType: + case TypeKind::kNullType: return static_cast(this)->DebugString(); - case Kind::kError: + case TypeKind::kError: return static_cast(this)->DebugString(); - case Kind::kDyn: + case TypeKind::kDyn: return static_cast(this)->DebugString(); - case Kind::kAny: + case TypeKind::kAny: return static_cast(this)->DebugString(); - case Kind::kType: + case TypeKind::kType: return static_cast(this)->DebugString(); - case Kind::kBool: + case TypeKind::kBool: return static_cast(this)->DebugString(); - case Kind::kInt: + case TypeKind::kInt: return static_cast(this)->DebugString(); - case Kind::kUint: + case TypeKind::kUint: return static_cast(this)->DebugString(); - case Kind::kDouble: + case TypeKind::kDouble: return static_cast(this)->DebugString(); - case Kind::kString: + case TypeKind::kString: return static_cast(this)->DebugString(); - case Kind::kBytes: + case TypeKind::kBytes: return static_cast(this)->DebugString(); - case Kind::kEnum: + case TypeKind::kEnum: return static_cast(this)->DebugString(); - case Kind::kDuration: + case TypeKind::kDuration: return static_cast(this)->DebugString(); - case Kind::kTimestamp: + case TypeKind::kTimestamp: return static_cast(this)->DebugString(); - case Kind::kList: + case TypeKind::kList: return static_cast(this)->DebugString(); - case Kind::kMap: + case TypeKind::kMap: return static_cast(this)->DebugString(); - case Kind::kStruct: + case TypeKind::kStruct: return static_cast(this)->DebugString(); - case Kind::kUnknown: + case TypeKind::kUnknown: return static_cast(this)->DebugString(); - case Kind::kWrapper: + case TypeKind::kWrapper: return static_cast(this)->DebugString(); - case Kind::kOpaque: + case TypeKind::kOpaque: return static_cast(this)->DebugString(); default: return "*unreachable*"; } } -bool Type::Equals(const Type& lhs, const Type& rhs, Kind kind) { +bool Type::Equals(const Type& lhs, const Type& rhs, TypeKind kind) { if (&lhs == &rhs) { return true; } switch (kind) { - case Kind::kNullType: + case TypeKind::kNullType: return true; - case Kind::kError: + case TypeKind::kError: return true; - case Kind::kDyn: + case TypeKind::kDyn: return true; - case Kind::kAny: + case TypeKind::kAny: return true; - case Kind::kType: + case TypeKind::kType: return true; - case Kind::kBool: + case TypeKind::kBool: return true; - case Kind::kInt: + case TypeKind::kInt: return true; - case Kind::kUint: + case TypeKind::kUint: return true; - case Kind::kDouble: + case TypeKind::kDouble: return true; - case Kind::kString: + case TypeKind::kString: return true; - case Kind::kBytes: + case TypeKind::kBytes: return true; - case Kind::kEnum: + case TypeKind::kEnum: return static_cast(lhs).name() == static_cast(rhs).name(); - case Kind::kDuration: + case TypeKind::kDuration: return true; - case Kind::kTimestamp: + case TypeKind::kTimestamp: return true; - case Kind::kList: + case TypeKind::kList: return static_cast(lhs).element() == static_cast(rhs).element(); - case Kind::kMap: + case TypeKind::kMap: return static_cast(lhs).key() == static_cast(rhs).key() && static_cast(lhs).value() == static_cast(rhs).value(); - case Kind::kStruct: + case TypeKind::kStruct: return static_cast(lhs).name() == static_cast(rhs).name(); - case Kind::kUnknown: + case TypeKind::kUnknown: return true; - case Kind::kWrapper: + case TypeKind::kWrapper: return static_cast(lhs).wrapped() == static_cast(rhs).wrapped(); - case Kind::kOpaque: { + case TypeKind::kOpaque: { if (static_cast(lhs).name() != static_cast(rhs).name()) { return false; @@ -224,89 +224,89 @@ bool Type::Equals(const Type& lhs, const Type& rhs, Kind kind) { } } -void Type::HashValue(const Type& type, Kind kind, absl::HashState state) { +void Type::HashValue(const Type& type, TypeKind kind, absl::HashState state) { switch (kind) { - case Kind::kNullType: + case TypeKind::kNullType: absl::HashState::combine(std::move(state), kind, static_cast(type).name()); return; - case Kind::kError: + case TypeKind::kError: absl::HashState::combine(std::move(state), kind, static_cast(type).name()); return; - case Kind::kDyn: + case TypeKind::kDyn: absl::HashState::combine(std::move(state), kind, static_cast(type).name()); return; - case Kind::kAny: + case TypeKind::kAny: absl::HashState::combine(std::move(state), kind, static_cast(type).name()); return; - case Kind::kType: + case TypeKind::kType: absl::HashState::combine(std::move(state), kind, static_cast(type).name()); return; - case Kind::kBool: + case TypeKind::kBool: absl::HashState::combine(std::move(state), kind, static_cast(type).name()); return; - case Kind::kInt: + case TypeKind::kInt: absl::HashState::combine(std::move(state), kind, static_cast(type).name()); return; - case Kind::kUint: + case TypeKind::kUint: absl::HashState::combine(std::move(state), kind, static_cast(type).name()); return; - case Kind::kDouble: + case TypeKind::kDouble: absl::HashState::combine(std::move(state), kind, static_cast(type).name()); return; - case Kind::kString: + case TypeKind::kString: absl::HashState::combine(std::move(state), kind, static_cast(type).name()); return; - case Kind::kBytes: + case TypeKind::kBytes: absl::HashState::combine(std::move(state), kind, static_cast(type).name()); return; - case Kind::kEnum: + case TypeKind::kEnum: absl::HashState::combine(std::move(state), kind, static_cast(type).name()); return; - case Kind::kDuration: + case TypeKind::kDuration: absl::HashState::combine(std::move(state), kind, static_cast(type).name()); return; - case Kind::kTimestamp: + case TypeKind::kTimestamp: absl::HashState::combine(std::move(state), kind, static_cast(type).name()); return; - case Kind::kList: + case TypeKind::kList: absl::HashState::combine(std::move(state), static_cast(type).element(), kind, static_cast(type).name()); return; - case Kind::kMap: + case TypeKind::kMap: absl::HashState::combine(std::move(state), static_cast(type).key(), static_cast(type).value(), kind, static_cast(type).name()); return; - case Kind::kStruct: + case TypeKind::kStruct: absl::HashState::combine(std::move(state), kind, static_cast(type).name()); return; - case Kind::kUnknown: + case TypeKind::kUnknown: absl::HashState::combine(std::move(state), kind, static_cast(type).name()); return; - case Kind::kWrapper: + case TypeKind::kWrapper: absl::HashState::combine( std::move(state), static_cast(type).wrapped(), kind, static_cast(type).name()); return; - case Kind::kOpaque: { + case TypeKind::kOpaque: { const auto& parameters = static_cast(type).parameters(); for (const auto& parameter : parameters) { @@ -338,7 +338,7 @@ bool TypeHandle::Equals(const TypeHandle& other) const { if (self == nullptr || that == nullptr) { return false; } - Kind kind = self->kind(); + TypeKind kind = self->kind(); return kind == that->kind() && Type::Equals(*self, *that, kind); } @@ -418,23 +418,23 @@ void TypeHandle::Destruct() { } void TypeHandle::Delete() const { - switch (data_.kind_heap()) { - case Kind::kList: + switch (KindToTypeKind(data_.kind_heap())) { + case TypeKind::kList: delete static_cast( static_cast(static_cast(data_.get_heap()))); return; - case Kind::kMap: + case TypeKind::kMap: delete static_cast( static_cast(static_cast(data_.get_heap()))); return; - case Kind::kEnum: + case TypeKind::kEnum: delete static_cast(static_cast(data_.get_heap())); return; - case Kind::kStruct: + case TypeKind::kStruct: delete static_cast( static_cast(data_.get_heap())); return; - case Kind::kOpaque: + case TypeKind::kOpaque: delete static_cast(static_cast(data_.get_heap())); return; default: diff --git a/base/type.h b/base/type.h index caee5e043..c77d647d8 100644 --- a/base/type.h +++ b/base/type.h @@ -54,7 +54,9 @@ class Type : public base_internal::Data { static const Type& Cast(const Type& type) { return type; } // Returns the type kind. - Kind kind() const { return base_internal::Metadata::Kind(*this); } + TypeKind kind() const { + return KindToTypeKind(base_internal::Metadata::Kind(*this)); + } // Returns the type name, i.e. "list". absl::string_view name() const; @@ -96,7 +98,7 @@ class Type : public base_internal::Data { friend class StructType; friend class ListType; friend class MapType; - template + template friend class base_internal::SimpleType; friend class WrapperType; friend class base_internal::TypeHandle; @@ -107,17 +109,17 @@ class Type : public base_internal::Data { // doesn't exist. absl::Span aliases() const; - static bool Equals(const Type& lhs, const Type& rhs, Kind kind); + static bool Equals(const Type& lhs, const Type& rhs, TypeKind kind); static bool Equals(const Type& lhs, const Type& rhs) { if (&lhs == &rhs) { return true; } - Kind lhs_kind = lhs.kind(); + TypeKind lhs_kind = lhs.kind(); return lhs_kind == rhs.kind() && Equals(lhs, rhs, lhs_kind); } - static void HashValue(const Type& type, Kind kind, absl::HashState state); + static void HashValue(const Type& type, TypeKind kind, absl::HashState state); static void HashValue(const Type& type, absl::HashState state) { HashValue(type, type.kind(), std::move(state)); @@ -198,9 +200,9 @@ class TypeHandle final { void HashValue(absl::HashState state) const; private: - static bool Equals(const Type& lhs, const Type& rhs, Kind kind); + static bool Equals(const Type& lhs, const Type& rhs, TypeKind kind); - static void HashValue(const Type& type, Kind kind, absl::HashState state); + static void HashValue(const Type& type, TypeKind kind, absl::HashState state); void CopyFrom(const TypeHandle& other); @@ -251,83 +253,83 @@ struct HandleTraits && !std::is_same_v)>> final : public HandleTraits {}; -template +template struct SimpleTypeName; template <> -struct SimpleTypeName { +struct SimpleTypeName { static constexpr absl::string_view value = "null_type"; }; template <> -struct SimpleTypeName { +struct SimpleTypeName { static constexpr absl::string_view value = "*error*"; }; template <> -struct SimpleTypeName { +struct SimpleTypeName { static constexpr absl::string_view value = "dyn"; }; template <> -struct SimpleTypeName { +struct SimpleTypeName { static constexpr absl::string_view value = "google.protobuf.Any"; }; template <> -struct SimpleTypeName { +struct SimpleTypeName { static constexpr absl::string_view value = "bool"; }; template <> -struct SimpleTypeName { +struct SimpleTypeName { static constexpr absl::string_view value = "int"; }; template <> -struct SimpleTypeName { +struct SimpleTypeName { static constexpr absl::string_view value = "uint"; }; template <> -struct SimpleTypeName { +struct SimpleTypeName { static constexpr absl::string_view value = "double"; }; template <> -struct SimpleTypeName { +struct SimpleTypeName { static constexpr absl::string_view value = "bytes"; }; template <> -struct SimpleTypeName { +struct SimpleTypeName { static constexpr absl::string_view value = "string"; }; template <> -struct SimpleTypeName { +struct SimpleTypeName { static constexpr absl::string_view value = "google.protobuf.Duration"; }; template <> -struct SimpleTypeName { +struct SimpleTypeName { static constexpr absl::string_view value = "google.protobuf.Timestamp"; }; template <> -struct SimpleTypeName { +struct SimpleTypeName { static constexpr absl::string_view value = "type"; }; template <> -struct SimpleTypeName { +struct SimpleTypeName { static constexpr absl::string_view value = "*unknown*"; }; -template +template class SimpleType : public Type, public InlineData { public: - static constexpr Kind kKind = K; + static constexpr TypeKind kKind = K; static constexpr absl::string_view kName = SimpleTypeName::value; static bool Is(const Type& type) { return type.kind() == kKind; } @@ -341,7 +343,7 @@ class SimpleType : public Type, public InlineData { SimpleType& operator=(const SimpleType&) = default; SimpleType& operator=(SimpleType&&) = default; - constexpr Kind kind() const { return kKind; } + constexpr TypeKind kind() const { return kKind; } constexpr absl::string_view name() const { return kName; } diff --git a/base/type_test.cc b/base/type_test.cc index bacc1557f..77761e0b5 100644 --- a/base/type_test.cc +++ b/base/type_test.cc @@ -26,8 +26,6 @@ #include "base/type_factory.h" #include "base/type_manager.h" #include "base/value.h" -#include "base/values/enum_value.h" -#include "base/values/struct_value.h" #include "internal/testing.h" namespace cel { @@ -296,84 +294,84 @@ void TestTypeIs(const Handle& type) { TEST_P(TypeTest, Null) { TypeFactory type_factory(memory_manager()); - EXPECT_EQ(type_factory.GetNullType()->kind(), Kind::kNullType); + EXPECT_EQ(type_factory.GetNullType()->kind(), TypeKind::kNullType); EXPECT_EQ(type_factory.GetNullType()->name(), "null_type"); TestTypeIs(type_factory.GetNullType()); } TEST_P(TypeTest, Error) { TypeFactory type_factory(memory_manager()); - EXPECT_EQ(type_factory.GetErrorType()->kind(), Kind::kError); + EXPECT_EQ(type_factory.GetErrorType()->kind(), TypeKind::kError); EXPECT_EQ(type_factory.GetErrorType()->name(), "*error*"); TestTypeIs(type_factory.GetErrorType()); } TEST_P(TypeTest, Dyn) { TypeFactory type_factory(memory_manager()); - EXPECT_EQ(type_factory.GetDynType()->kind(), Kind::kDyn); + EXPECT_EQ(type_factory.GetDynType()->kind(), TypeKind::kDyn); EXPECT_EQ(type_factory.GetDynType()->name(), "dyn"); TestTypeIs(type_factory.GetDynType()); } TEST_P(TypeTest, Any) { TypeFactory type_factory(memory_manager()); - EXPECT_EQ(type_factory.GetAnyType()->kind(), Kind::kAny); + EXPECT_EQ(type_factory.GetAnyType()->kind(), TypeKind::kAny); EXPECT_EQ(type_factory.GetAnyType()->name(), "google.protobuf.Any"); TestTypeIs(type_factory.GetAnyType()); } TEST_P(TypeTest, Bool) { TypeFactory type_factory(memory_manager()); - EXPECT_EQ(type_factory.GetBoolType()->kind(), Kind::kBool); + EXPECT_EQ(type_factory.GetBoolType()->kind(), TypeKind::kBool); EXPECT_EQ(type_factory.GetBoolType()->name(), "bool"); TestTypeIs(type_factory.GetBoolType()); } TEST_P(TypeTest, Int) { TypeFactory type_factory(memory_manager()); - EXPECT_EQ(type_factory.GetIntType()->kind(), Kind::kInt); + EXPECT_EQ(type_factory.GetIntType()->kind(), TypeKind::kInt); EXPECT_EQ(type_factory.GetIntType()->name(), "int"); TestTypeIs(type_factory.GetIntType()); } TEST_P(TypeTest, Uint) { TypeFactory type_factory(memory_manager()); - EXPECT_EQ(type_factory.GetUintType()->kind(), Kind::kUint); + EXPECT_EQ(type_factory.GetUintType()->kind(), TypeKind::kUint); EXPECT_EQ(type_factory.GetUintType()->name(), "uint"); TestTypeIs(type_factory.GetUintType()); } TEST_P(TypeTest, Double) { TypeFactory type_factory(memory_manager()); - EXPECT_EQ(type_factory.GetDoubleType()->kind(), Kind::kDouble); + EXPECT_EQ(type_factory.GetDoubleType()->kind(), TypeKind::kDouble); EXPECT_EQ(type_factory.GetDoubleType()->name(), "double"); TestTypeIs(type_factory.GetDoubleType()); } TEST_P(TypeTest, String) { TypeFactory type_factory(memory_manager()); - EXPECT_EQ(type_factory.GetStringType()->kind(), Kind::kString); + EXPECT_EQ(type_factory.GetStringType()->kind(), TypeKind::kString); EXPECT_EQ(type_factory.GetStringType()->name(), "string"); TestTypeIs(type_factory.GetStringType()); } TEST_P(TypeTest, Bytes) { TypeFactory type_factory(memory_manager()); - EXPECT_EQ(type_factory.GetBytesType()->kind(), Kind::kBytes); + EXPECT_EQ(type_factory.GetBytesType()->kind(), TypeKind::kBytes); EXPECT_EQ(type_factory.GetBytesType()->name(), "bytes"); TestTypeIs(type_factory.GetBytesType()); } TEST_P(TypeTest, Duration) { TypeFactory type_factory(memory_manager()); - EXPECT_EQ(type_factory.GetDurationType()->kind(), Kind::kDuration); + EXPECT_EQ(type_factory.GetDurationType()->kind(), TypeKind::kDuration); EXPECT_EQ(type_factory.GetDurationType()->name(), "google.protobuf.Duration"); TestTypeIs(type_factory.GetDurationType()); } TEST_P(TypeTest, Timestamp) { TypeFactory type_factory(memory_manager()); - EXPECT_EQ(type_factory.GetTimestampType()->kind(), Kind::kTimestamp); + EXPECT_EQ(type_factory.GetTimestampType()->kind(), TypeKind::kTimestamp); EXPECT_EQ(type_factory.GetTimestampType()->name(), "google.protobuf.Timestamp"); TestTypeIs(type_factory.GetTimestampType()); @@ -383,7 +381,7 @@ TEST_P(TypeTest, Enum) { TypeFactory type_factory(memory_manager()); ASSERT_OK_AND_ASSIGN(auto enum_type, type_factory.CreateEnumType()); - EXPECT_EQ(enum_type->kind(), Kind::kEnum); + EXPECT_EQ(enum_type->kind(), TypeKind::kEnum); EXPECT_EQ(enum_type->name(), "test_enum.TestEnum"); TestTypeIs(enum_type); } @@ -394,7 +392,7 @@ TEST_P(TypeTest, Struct) { ASSERT_OK_AND_ASSIGN( auto struct_type, type_manager.type_factory().CreateStructType()); - EXPECT_EQ(struct_type->kind(), Kind::kStruct); + EXPECT_EQ(struct_type->kind(), TypeKind::kStruct); EXPECT_EQ(struct_type->name(), "test_struct.TestStruct"); TestTypeIs(struct_type); } @@ -405,7 +403,7 @@ TEST_P(TypeTest, List) { type_factory.CreateListType(type_factory.GetBoolType())); EXPECT_EQ(list_type, Must(type_factory.CreateListType(type_factory.GetBoolType()))); - EXPECT_EQ(list_type->kind(), Kind::kList); + EXPECT_EQ(list_type->kind(), TypeKind::kList); EXPECT_EQ(list_type->name(), "list"); EXPECT_EQ(list_type->element(), type_factory.GetBoolType()); TestTypeIs(list_type); @@ -422,7 +420,7 @@ TEST_P(TypeTest, Map) { EXPECT_NE(map_type, Must(type_factory.CreateMapType(type_factory.GetBoolType(), type_factory.GetStringType()))); - EXPECT_EQ(map_type->kind(), Kind::kMap); + EXPECT_EQ(map_type->kind(), TypeKind::kMap); EXPECT_EQ(map_type->name(), "map"); EXPECT_EQ(map_type->key(), type_factory.GetStringType()); EXPECT_EQ(map_type->value(), type_factory.GetBoolType()); @@ -431,14 +429,14 @@ TEST_P(TypeTest, Map) { TEST_P(TypeTest, TypeType) { TypeFactory type_factory(memory_manager()); - EXPECT_EQ(type_factory.GetTypeType()->kind(), Kind::kType); + EXPECT_EQ(type_factory.GetTypeType()->kind(), TypeKind::kType); EXPECT_EQ(type_factory.GetTypeType()->name(), "type"); TestTypeIs(type_factory.GetTypeType()); } TEST_P(TypeTest, UnknownType) { TypeFactory type_factory(memory_manager()); - EXPECT_EQ(type_factory.GetUnknownType()->kind(), Kind::kUnknown); + EXPECT_EQ(type_factory.GetUnknownType()->kind(), TypeKind::kUnknown); EXPECT_EQ(type_factory.GetUnknownType()->name(), "*unknown*"); TestTypeIs(type_factory.GetUnknownType()); } @@ -447,7 +445,7 @@ TEST_P(TypeTest, OptionalType) { TypeFactory type_factory(memory_manager()); ASSERT_OK_AND_ASSIGN(auto optional_type, type_factory.CreateOptionalType( type_factory.GetStringType())); - EXPECT_EQ(optional_type->kind(), Kind::kOpaque); + EXPECT_EQ(optional_type->kind(), TypeKind::kOpaque); EXPECT_EQ(optional_type->name(), "optional"); TestTypeIs(optional_type); TestTypeIs(optional_type->type().As()); @@ -455,7 +453,7 @@ TEST_P(TypeTest, OptionalType) { TEST_P(TypeTest, BoolWrapperType) { TypeFactory type_factory(memory_manager()); - EXPECT_EQ(type_factory.GetBoolWrapperType()->kind(), Kind::kWrapper); + EXPECT_EQ(type_factory.GetBoolWrapperType()->kind(), TypeKind::kWrapper); EXPECT_EQ(type_factory.GetBoolWrapperType()->name(), "google.protobuf.BoolValue"); TestTypeIs(type_factory.GetBoolWrapperType()); @@ -463,7 +461,7 @@ TEST_P(TypeTest, BoolWrapperType) { TEST_P(TypeTest, ByteWrapperType) { TypeFactory type_factory(memory_manager()); - EXPECT_EQ(type_factory.GetBytesWrapperType()->kind(), Kind::kWrapper); + EXPECT_EQ(type_factory.GetBytesWrapperType()->kind(), TypeKind::kWrapper); EXPECT_EQ(type_factory.GetBytesWrapperType()->name(), "google.protobuf.BytesValue"); TestTypeIs(type_factory.GetBytesWrapperType()); @@ -471,7 +469,7 @@ TEST_P(TypeTest, ByteWrapperType) { TEST_P(TypeTest, DoubleWrapperType) { TypeFactory type_factory(memory_manager()); - EXPECT_EQ(type_factory.GetDoubleWrapperType()->kind(), Kind::kWrapper); + EXPECT_EQ(type_factory.GetDoubleWrapperType()->kind(), TypeKind::kWrapper); EXPECT_EQ(type_factory.GetDoubleWrapperType()->name(), "google.protobuf.DoubleValue"); TestTypeIs(type_factory.GetDoubleWrapperType()); @@ -479,7 +477,7 @@ TEST_P(TypeTest, DoubleWrapperType) { TEST_P(TypeTest, IntWrapperType) { TypeFactory type_factory(memory_manager()); - EXPECT_EQ(type_factory.GetIntWrapperType()->kind(), Kind::kWrapper); + EXPECT_EQ(type_factory.GetIntWrapperType()->kind(), TypeKind::kWrapper); EXPECT_EQ(type_factory.GetIntWrapperType()->name(), "google.protobuf.Int64Value"); TestTypeIs(type_factory.GetIntWrapperType()); @@ -487,7 +485,7 @@ TEST_P(TypeTest, IntWrapperType) { TEST_P(TypeTest, StringWrapperType) { TypeFactory type_factory(memory_manager()); - EXPECT_EQ(type_factory.GetStringWrapperType()->kind(), Kind::kWrapper); + EXPECT_EQ(type_factory.GetStringWrapperType()->kind(), TypeKind::kWrapper); EXPECT_EQ(type_factory.GetStringWrapperType()->name(), "google.protobuf.StringValue"); TestTypeIs(type_factory.GetStringWrapperType()); @@ -495,7 +493,7 @@ TEST_P(TypeTest, StringWrapperType) { TEST_P(TypeTest, UintWrapperType) { TypeFactory type_factory(memory_manager()); - EXPECT_EQ(type_factory.GetUintWrapperType()->kind(), Kind::kWrapper); + EXPECT_EQ(type_factory.GetUintWrapperType()->kind(), TypeKind::kWrapper); EXPECT_EQ(type_factory.GetUintWrapperType()->name(), "google.protobuf.UInt64Value"); TestTypeIs(type_factory.GetUintWrapperType()); diff --git a/base/types/any_type.h b/base/types/any_type.h index 96f21d582..85741cd10 100644 --- a/base/types/any_type.h +++ b/base/types/any_type.h @@ -23,9 +23,9 @@ namespace cel { class AnyValue; -class AnyType final : public base_internal::SimpleType { +class AnyType final : public base_internal::SimpleType { private: - using Base = base_internal::SimpleType; + using Base = base_internal::SimpleType; public: using Base::kKind; diff --git a/base/types/bool_type.h b/base/types/bool_type.h index 45667652c..afccf29fb 100644 --- a/base/types/bool_type.h +++ b/base/types/bool_type.h @@ -24,9 +24,9 @@ namespace cel { class BoolValue; class BoolWrapperType; -class BoolType final : public base_internal::SimpleType { +class BoolType final : public base_internal::SimpleType { private: - using Base = base_internal::SimpleType; + using Base = base_internal::SimpleType; public: using Base::kKind; diff --git a/base/types/bytes_type.h b/base/types/bytes_type.h index ce097d4e5..fb7413822 100644 --- a/base/types/bytes_type.h +++ b/base/types/bytes_type.h @@ -24,9 +24,9 @@ namespace cel { class BytesValue; class BytesWrapperType; -class BytesType final : public base_internal::SimpleType { +class BytesType final : public base_internal::SimpleType { private: - using Base = base_internal::SimpleType; + using Base = base_internal::SimpleType; public: using Base::kKind; diff --git a/base/types/double_type.h b/base/types/double_type.h index 61b158433..5722b905d 100644 --- a/base/types/double_type.h +++ b/base/types/double_type.h @@ -24,9 +24,9 @@ namespace cel { class DoubleValue; class DoubleWrapperType; -class DoubleType final : public base_internal::SimpleType { +class DoubleType final : public base_internal::SimpleType { private: - using Base = base_internal::SimpleType; + using Base = base_internal::SimpleType; public: using Base::kKind; diff --git a/base/types/duration_type.h b/base/types/duration_type.h index 0b3ea5a34..147b8dd69 100644 --- a/base/types/duration_type.h +++ b/base/types/duration_type.h @@ -23,9 +23,10 @@ namespace cel { class DurationValue; -class DurationType final : public base_internal::SimpleType { +class DurationType final + : public base_internal::SimpleType { private: - using Base = base_internal::SimpleType; + using Base = base_internal::SimpleType; public: using Base::kKind; diff --git a/base/types/dyn_type.h b/base/types/dyn_type.h index 1dc509697..4db9cef21 100644 --- a/base/types/dyn_type.h +++ b/base/types/dyn_type.h @@ -25,9 +25,9 @@ namespace cel { class DynValue; -class DynType final : public base_internal::SimpleType { +class DynType final : public base_internal::SimpleType { private: - using Base = base_internal::SimpleType; + using Base = base_internal::SimpleType; public: using Base::kKind; diff --git a/base/types/enum_type.h b/base/types/enum_type.h index 7d8e4c314..a74791342 100644 --- a/base/types/enum_type.h +++ b/base/types/enum_type.h @@ -85,7 +85,7 @@ class EnumType : public Type, public base_internal::HeapData { absl::variant data_; }; - static constexpr Kind kKind = Kind::kEnum; + static constexpr TypeKind kKind = TypeKind::kEnum; using Type::Is; @@ -96,7 +96,7 @@ class EnumType : public Type, public base_internal::HeapData { return static_cast(type); } - Kind kind() const { return kKind; } + TypeKind kind() const { return kKind; } virtual absl::string_view name() const ABSL_ATTRIBUTE_LIFETIME_BOUND = 0; diff --git a/base/types/error_type.h b/base/types/error_type.h index 040e28888..30ae75df9 100644 --- a/base/types/error_type.h +++ b/base/types/error_type.h @@ -23,9 +23,9 @@ namespace cel { class ErrorValue; -class ErrorType final : public base_internal::SimpleType { +class ErrorType final : public base_internal::SimpleType { private: - using Base = base_internal::SimpleType; + using Base = base_internal::SimpleType; public: using Base::kKind; diff --git a/base/types/int_type.h b/base/types/int_type.h index 3bc1f92e9..fb2138dd6 100644 --- a/base/types/int_type.h +++ b/base/types/int_type.h @@ -24,9 +24,9 @@ namespace cel { class IntValue; class IntWrapperType; -class IntType final : public base_internal::SimpleType { +class IntType final : public base_internal::SimpleType { private: - using Base = base_internal::SimpleType; + using Base = base_internal::SimpleType; public: using Base::kKind; diff --git a/base/types/list_type.cc b/base/types/list_type.cc index 777b46333..f552a7aa8 100644 --- a/base/types/list_type.cc +++ b/base/types/list_type.cc @@ -56,27 +56,27 @@ const Handle& ListType::element() const { absl::StatusOr> ListType::NewValueBuilder( ValueFactory& value_factory) const { switch (element()->kind()) { - case Kind::kBool: + case TypeKind::kBool: return MakeUnique>( value_factory.memory_manager(), base_internal::kComposedListType, value_factory, handle_from_this()); - case Kind::kInt: + case TypeKind::kInt: return MakeUnique>( value_factory.memory_manager(), base_internal::kComposedListType, value_factory, handle_from_this()); - case Kind::kUint: + case TypeKind::kUint: return MakeUnique>( value_factory.memory_manager(), base_internal::kComposedListType, value_factory, handle_from_this()); - case Kind::kDouble: + case TypeKind::kDouble: return MakeUnique>( value_factory.memory_manager(), base_internal::kComposedListType, value_factory, handle_from_this()); - case Kind::kDuration: + case TypeKind::kDuration: return MakeUnique>( value_factory.memory_manager(), base_internal::kComposedListType, value_factory, handle_from_this()); - case Kind::kTimestamp: + case TypeKind::kTimestamp: return MakeUnique>( value_factory.memory_manager(), base_internal::kComposedListType, value_factory, handle_from_this()); diff --git a/base/types/list_type.h b/base/types/list_type.h index d60cd516b..6cd1e512c 100644 --- a/base/types/list_type.h +++ b/base/types/list_type.h @@ -43,13 +43,13 @@ class ListType : public Type, public base_internal::EnableHandleFromThis { public: - static constexpr Kind kKind = Kind::kList; + static constexpr TypeKind kKind = TypeKind::kList; static bool Is(const Type& type) { return type.kind() == kKind; } - Kind kind() const { return kKind; } + TypeKind kind() const { return kKind; } - absl::string_view name() const { return KindToString(kind()); } + absl::string_view name() const { return TypeKindToString(kind()); } std::string DebugString() const; diff --git a/base/types/map_type.cc b/base/types/map_type.cc index 12c2b3d23..2f3691fe5 100644 --- a/base/types/map_type.cc +++ b/base/types/map_type.cc @@ -67,22 +67,22 @@ template absl::StatusOr> NewMapValueBuilderFor( ValueFactory& value_factory, Handle type) { switch (type->value()->kind()) { - case Kind::kBool: + case TypeKind::kBool: return MakeUnique>( value_factory.memory_manager(), value_factory, std::move(type)); - case Kind::kInt: + case TypeKind::kInt: return MakeUnique>( value_factory.memory_manager(), value_factory, std::move(type)); - case Kind::kUint: + case TypeKind::kUint: return MakeUnique>( value_factory.memory_manager(), value_factory, std::move(type)); - case Kind::kDouble: + case TypeKind::kDouble: return MakeUnique>( value_factory.memory_manager(), value_factory, std::move(type)); - case Kind::kDuration: + case TypeKind::kDuration: return MakeUnique>( value_factory.memory_manager(), value_factory, std::move(type)); - case Kind::kTimestamp: + case TypeKind::kTimestamp: return MakeUnique>( value_factory.memory_manager(), value_factory, std::move(type)); default: @@ -96,12 +96,12 @@ absl::StatusOr> NewMapValueBuilderFor( absl::StatusOr> MapType::NewValueBuilder( ValueFactory& value_factory) const { switch (key()->kind()) { - case Kind::kBool: + case TypeKind::kBool: return NewMapValueBuilderFor(value_factory, handle_from_this()); - case Kind::kInt: + case TypeKind::kInt: return NewMapValueBuilderFor(value_factory, handle_from_this()); - case Kind::kUint: + case TypeKind::kUint: return NewMapValueBuilderFor(value_factory, handle_from_this()); default: diff --git a/base/types/map_type.h b/base/types/map_type.h index 14aba4557..345300c8f 100644 --- a/base/types/map_type.h +++ b/base/types/map_type.h @@ -43,7 +43,7 @@ class MapValueBuilderInterface; class MapType : public Type, public base_internal::EnableHandleFromThis { public: - static constexpr Kind kKind = Kind::kMap; + static constexpr TypeKind kKind = TypeKind::kMap; static bool Is(const Type& type) { return type.kind() == kKind; } @@ -54,9 +54,9 @@ class MapType : public Type, return static_cast(type); } - Kind kind() const { return kKind; } + TypeKind kind() const { return kKind; } - absl::string_view name() const { return KindToString(kind()); } + absl::string_view name() const { return TypeKindToString(kind()); } std::string DebugString() const; diff --git a/base/types/null_type.h b/base/types/null_type.h index d733ef047..f77fba48e 100644 --- a/base/types/null_type.h +++ b/base/types/null_type.h @@ -23,9 +23,9 @@ namespace cel { class NullValue; -class NullType final : public base_internal::SimpleType { +class NullType final : public base_internal::SimpleType { private: - using Base = base_internal::SimpleType; + using Base = base_internal::SimpleType; public: using Base::kKind; diff --git a/base/types/opaque_type.h b/base/types/opaque_type.h index 3063a63c8..70dafde9e 100644 --- a/base/types/opaque_type.h +++ b/base/types/opaque_type.h @@ -28,7 +28,7 @@ namespace cel { class OpaqueType : public Type, public base_internal::HeapData { public: - static constexpr Kind kKind = Kind::kOpaque; + static constexpr TypeKind kKind = TypeKind::kOpaque; static bool Is(const Type& type) { return type.kind() == kKind; } @@ -39,7 +39,7 @@ class OpaqueType : public Type, public base_internal::HeapData { return static_cast(type); } - Kind kind() const { return kKind; } + TypeKind kind() const { return kKind; } virtual absl::string_view name() const = 0; diff --git a/base/types/string_type.h b/base/types/string_type.h index 3c292c88e..c786a719f 100644 --- a/base/types/string_type.h +++ b/base/types/string_type.h @@ -24,9 +24,9 @@ namespace cel { class StringValue; class StringWrapperType; -class StringType final : public base_internal::SimpleType { +class StringType final : public base_internal::SimpleType { private: - using Base = base_internal::SimpleType; + using Base = base_internal::SimpleType; public: using Base::kKind; diff --git a/base/types/struct_type.h b/base/types/struct_type.h index e02216600..3f8c1a53c 100644 --- a/base/types/struct_type.h +++ b/base/types/struct_type.h @@ -94,7 +94,7 @@ class StructType : public Type { absl::variant data_; }; - static constexpr Kind kKind = Kind::kStruct; + static constexpr TypeKind kKind = TypeKind::kStruct; static bool Is(const Type& type) { return type.kind() == kKind; } @@ -105,7 +105,7 @@ class StructType : public Type { return static_cast(type); } - Kind kind() const { return kKind; } + TypeKind kind() const { return kKind; } absl::string_view name() const ABSL_ATTRIBUTE_LIFETIME_BOUND; diff --git a/base/types/timestamp_type.h b/base/types/timestamp_type.h index 3042c694b..1dba44a5c 100644 --- a/base/types/timestamp_type.h +++ b/base/types/timestamp_type.h @@ -23,9 +23,10 @@ namespace cel { class TimestampValue; -class TimestampType final : public base_internal::SimpleType { +class TimestampType final + : public base_internal::SimpleType { private: - using Base = base_internal::SimpleType; + using Base = base_internal::SimpleType; public: using Base::kKind; diff --git a/base/types/type_type.h b/base/types/type_type.h index 2eb483b5d..54d8fac07 100644 --- a/base/types/type_type.h +++ b/base/types/type_type.h @@ -23,9 +23,9 @@ namespace cel { class TypeValue; -class TypeType final : public base_internal::SimpleType { +class TypeType final : public base_internal::SimpleType { private: - using Base = base_internal::SimpleType; + using Base = base_internal::SimpleType; public: using Base::kKind; diff --git a/base/types/uint_type.h b/base/types/uint_type.h index 62080b590..addefeefc 100644 --- a/base/types/uint_type.h +++ b/base/types/uint_type.h @@ -24,9 +24,9 @@ namespace cel { class UintValue; class UintWrapperType; -class UintType final : public base_internal::SimpleType { +class UintType final : public base_internal::SimpleType { private: - using Base = base_internal::SimpleType; + using Base = base_internal::SimpleType; public: using Base::kKind; diff --git a/base/types/unknown_type.h b/base/types/unknown_type.h index d917444ab..0a84dc026 100644 --- a/base/types/unknown_type.h +++ b/base/types/unknown_type.h @@ -23,9 +23,9 @@ namespace cel { class UnknownValue; -class UnknownType final : public base_internal::SimpleType { +class UnknownType final : public base_internal::SimpleType { private: - using Base = base_internal::SimpleType; + using Base = base_internal::SimpleType; public: using Base::kKind; diff --git a/base/types/wrapper_type.h b/base/types/wrapper_type.h index 1b74379e3..e9fec2ebb 100644 --- a/base/types/wrapper_type.h +++ b/base/types/wrapper_type.h @@ -50,9 +50,9 @@ class WrapperType : public Type, base_internal::InlineData { using Base = base_internal::InlineData; public: - static constexpr Kind kKind = Kind::kWrapper; + static constexpr TypeKind kKind = TypeKind::kWrapper; - static bool Is(const Type& type) { return type.kind() == Kind::kWrapper; } + static bool Is(const Type& type) { return type.kind() == TypeKind::kWrapper; } using Type::Is; @@ -61,7 +61,7 @@ class WrapperType : public Type, base_internal::InlineData { return static_cast(type); } - constexpr Kind kind() const { return kKind; } + constexpr TypeKind kind() const { return kKind; } absl::string_view name() const; @@ -105,7 +105,7 @@ class BoolWrapperType final : public WrapperType { static bool Is(const Type& type) { return WrapperType::Is(type) && static_cast(type).wrapped()->kind() == - Kind::kBool; + TypeKind::kBool; } using WrapperType::Is; @@ -142,7 +142,7 @@ class BytesWrapperType final : public WrapperType { static bool Is(const Type& type) { return WrapperType::Is(type) && static_cast(type).wrapped()->kind() == - Kind::kBytes; + TypeKind::kBytes; } using WrapperType::Is; @@ -179,7 +179,7 @@ class DoubleWrapperType final : public WrapperType { static bool Is(const Type& type) { return WrapperType::Is(type) && static_cast(type).wrapped()->kind() == - Kind::kDouble; + TypeKind::kDouble; } using WrapperType::Is; @@ -220,7 +220,7 @@ class IntWrapperType final : public WrapperType { static bool Is(const Type& type) { return WrapperType::Is(type) && static_cast(type).wrapped()->kind() == - Kind::kInt; + TypeKind::kInt; } using WrapperType::Is; @@ -261,7 +261,7 @@ class StringWrapperType final : public WrapperType { static bool Is(const Type& type) { return WrapperType::Is(type) && static_cast(type).wrapped()->kind() == - Kind::kString; + TypeKind::kString; } using WrapperType::Is; @@ -298,7 +298,7 @@ class UintWrapperType final : public WrapperType { static bool Is(const Type& type) { return WrapperType::Is(type) && static_cast(type).wrapped()->kind() == - Kind::kUint; + TypeKind::kUint; } using WrapperType::Is; diff --git a/base/value.h b/base/value.h index 3cf89933b..6ff9fa0ba 100644 --- a/base/value.h +++ b/base/value.h @@ -231,7 +231,7 @@ struct HandleTraits && template class SimpleValue : public Value, InlineData { public: - static constexpr Kind kKind = T::kKind; + static constexpr Kind kKind = static_cast(T::kKind); static bool Is(const Value& value) { return value.kind() == kKind; } diff --git a/base/values/bytes_value.h b/base/values/bytes_value.h index d90311980..4ec7e041f 100644 --- a/base/values/bytes_value.h +++ b/base/values/bytes_value.h @@ -38,7 +38,7 @@ class ValueFactory; class BytesValue : public Value { public: - static constexpr Kind kKind = BytesType::kKind; + static constexpr Kind kKind = Kind::kBytes; static Handle Empty(ValueFactory& value_factory); diff --git a/base/values/enum_value.h b/base/values/enum_value.h index 7770e31c3..89e8f004d 100644 --- a/base/values/enum_value.h +++ b/base/values/enum_value.h @@ -36,7 +36,7 @@ class ValueFactory; // EnumValue represents a single constant belonging to cel::EnumType. class EnumValue final : public Value, public base_internal::InlineData { public: - static constexpr Kind kKind = EnumType::kKind; + static constexpr Kind kKind = Kind::kEnum; static bool Is(const Value& value) { return value.kind() == kKind; } diff --git a/base/values/error_value.h b/base/values/error_value.h index e7b1a8cef..0924a0a14 100644 --- a/base/values/error_value.h +++ b/base/values/error_value.h @@ -35,7 +35,7 @@ class ErrorValue final : public Value, public base_internal::InlineData { ABSL_ATTRIBUTE_PURE_FUNCTION static std::string DebugString( const absl::Status& value); - static constexpr Kind kKind = ErrorType::kKind; + static constexpr Kind kKind = Kind::kError; static bool Is(const Value& value) { return value.kind() == kKind; } diff --git a/base/values/list_value.h b/base/values/list_value.h index 732c2e268..60097cee3 100644 --- a/base/values/list_value.h +++ b/base/values/list_value.h @@ -49,7 +49,7 @@ class ListValue : public Value { template using Builder = ListValueBuilder; - static constexpr Kind kKind = ListType::kKind; + static constexpr Kind kKind = Kind::kList; static bool Is(const Value& value) { return value.kind() == kKind; } diff --git a/base/values/map_value.h b/base/values/map_value.h index ca75538ac..716b2de8e 100644 --- a/base/values/map_value.h +++ b/base/values/map_value.h @@ -51,7 +51,7 @@ class MapValue : public Value { template using Builder = MapValueBuilder; - static constexpr Kind kKind = MapType::kKind; + static constexpr Kind kKind = Kind::kMap; static bool Is(const Value& value) { return value.kind() == kKind; } diff --git a/base/values/string_value.h b/base/values/string_value.h index f5eb1a419..69dd7b462 100644 --- a/base/values/string_value.h +++ b/base/values/string_value.h @@ -39,7 +39,7 @@ class ValueFactory; class StringValue : public Value { public: - static constexpr Kind kKind = StringType::kKind; + static constexpr Kind kKind = Kind::kString; static Handle Empty(ValueFactory& value_factory); diff --git a/base/values/type_value.h b/base/values/type_value.h index 09bf3f398..bb61f406e 100644 --- a/base/values/type_value.h +++ b/base/values/type_value.h @@ -31,7 +31,7 @@ namespace cel { class TypeValue : public Value { public: - static constexpr Kind kKind = TypeType::kKind; + static constexpr Kind kKind = Kind::kType; static bool Is(const Value& value) { return value.kind() == kKind; } diff --git a/base/values/unknown_value.h b/base/values/unknown_value.h index 8f559a89e..eb34f30e3 100644 --- a/base/values/unknown_value.h +++ b/base/values/unknown_value.h @@ -29,7 +29,7 @@ namespace cel { class UnknownValue final : public Value, public base_internal::InlineData { public: - static constexpr Kind kKind = UnknownType::kKind; + static constexpr Kind kKind = Kind::kUnknown; static bool Is(const Value& value) { return value.kind() == kKind; } diff --git a/extensions/protobuf/BUILD b/extensions/protobuf/BUILD index abf4a4f0f..20801d6db 100644 --- a/extensions/protobuf/BUILD +++ b/extensions/protobuf/BUILD @@ -148,12 +148,11 @@ cc_library( deps = [ ":memory_manager", ":type", + "//base:data", "//base:handle", "//base:kind", "//base:memory", "//base:owner", - "//base:type", - "//base:value", "//eval/internal:errors", "//eval/internal:interop", "//eval/public:message_wrapper", diff --git a/extensions/protobuf/struct_value.cc b/extensions/protobuf/struct_value.cc index 7d952c263..6ac6eac25 100644 --- a/extensions/protobuf/struct_value.cc +++ b/extensions/protobuf/struct_value.cc @@ -1552,9 +1552,9 @@ class ParsedProtoMapValue : public CEL_MAP_VALUE_CLASS { ProtoType::Resolve(context.value_factory().type_manager(), *value_desc->enum_type())); switch (type->kind()) { - case Kind::kNullType: + case TypeKind::kNullType: return context.value_factory().GetNullValue(); - case Kind::kEnum: + case TypeKind::kEnum: return context.value_factory().CreateEnumValue( std::move(type).As(), proto_value.GetEnumValue()); @@ -1571,40 +1571,40 @@ class ParsedProtoMapValue : public CEL_MAP_VALUE_CLASS { ProtoType::Resolve(context.value_factory().type_manager(), *value_desc->message_type())); switch (type->kind()) { - case Kind::kDuration: { + case TypeKind::kDuration: { CEL_ASSIGN_OR_RETURN( auto duration, protobuf_internal::AbslDurationFromDurationProto( proto_value.GetMessageValue())); return context.value_factory().CreateUncheckedDurationValue( duration); } - case Kind::kTimestamp: { + case TypeKind::kTimestamp: { CEL_ASSIGN_OR_RETURN(auto time, protobuf_internal::AbslTimeFromTimestampProto( proto_value.GetMessageValue())); return context.value_factory().CreateUncheckedTimestampValue(time); } - case Kind::kList: + case TypeKind::kList: // google.protobuf.ListValue return protobuf_internal::CreateBorrowedListValue( owner_from_this(), context.value_factory(), proto_value.GetMessageValue()); - case Kind::kMap: + case TypeKind::kMap: // google.protobuf.Struct return protobuf_internal::CreateBorrowedStruct( owner_from_this(), context.value_factory(), proto_value.GetMessageValue()); - case Kind::kDyn: + case TypeKind::kDyn: // google.protobuf.Value return protobuf_internal::CreateBorrowedValue( owner_from_this(), context.value_factory(), proto_value.GetMessageValue()); - case Kind::kAny: + case TypeKind::kAny: return ProtoValue::Create(context.value_factory(), proto_value.GetMessageValue()); - case Kind::kWrapper: + case TypeKind::kWrapper: switch (type->As().wrapped()->kind()) { - case Kind::kBool: { + case TypeKind::kBool: { // google.protobuf.BoolValue, mapped to CEL primitive bool type // for map values. CEL_ASSIGN_OR_RETURN(auto wrapped, @@ -1612,7 +1612,7 @@ class ParsedProtoMapValue : public CEL_MAP_VALUE_CLASS { proto_value.GetMessageValue())); return context.value_factory().CreateBoolValue(wrapped); } - case Kind::kBytes: { + case TypeKind::kBytes: { // google.protobuf.BytesValue, mapped to CEL primitive bytes // type for map values. CEL_ASSIGN_OR_RETURN(auto wrapped, @@ -1621,7 +1621,7 @@ class ParsedProtoMapValue : public CEL_MAP_VALUE_CLASS { return context.value_factory().CreateBytesValue( std::move(wrapped)); } - case Kind::kDouble: { + case TypeKind::kDouble: { // google.protobuf.{FloatValue,DoubleValue}, mapped to CEL // primitive double type for map values. CEL_ASSIGN_OR_RETURN(auto wrapped, @@ -1629,7 +1629,7 @@ class ParsedProtoMapValue : public CEL_MAP_VALUE_CLASS { proto_value.GetMessageValue())); return context.value_factory().CreateDoubleValue(wrapped); } - case Kind::kInt: { + case TypeKind::kInt: { // google.protobuf.{Int32Value,Int64Value}, mapped to CEL // primitive int type for map values. CEL_ASSIGN_OR_RETURN(auto wrapped, @@ -1637,7 +1637,7 @@ class ParsedProtoMapValue : public CEL_MAP_VALUE_CLASS { proto_value.GetMessageValue())); return context.value_factory().CreateIntValue(wrapped); } - case Kind::kString: { + case TypeKind::kString: { // google.protobuf.StringValue, mapped to CEL primitive bytes // type for map values. CEL_ASSIGN_OR_RETURN(auto wrapped, @@ -1646,7 +1646,7 @@ class ParsedProtoMapValue : public CEL_MAP_VALUE_CLASS { return context.value_factory().CreateUncheckedStringValue( std::move(wrapped)); } - case Kind::kUint: { + case TypeKind::kUint: { // google.protobuf.{UInt32Value,UInt64Value}, mapped to CEL // primitive uint type for map values. CEL_ASSIGN_OR_RETURN(auto wrapped, @@ -1657,7 +1657,7 @@ class ParsedProtoMapValue : public CEL_MAP_VALUE_CLASS { default: ABSL_UNREACHABLE(); } - case Kind::kStruct: + case TypeKind::kStruct: return context.value_factory() .CreateBorrowedStructValue< protobuf_internal::DynamicMemberParsedProtoStructValue>( @@ -2341,21 +2341,21 @@ absl::StatusOr> ParsedProtoStructValue::GetRepeatedField( ABSL_FALLTHROUGH_INTENDED; case google::protobuf::FieldDescriptor::TYPE_MESSAGE: switch (field.type.As()->element()->kind()) { - case Kind::kDuration: + case TypeKind::kDuration: return context.value_factory() .CreateBorrowedListValue< ParsedProtoListValue>( owner_from_this(), field.type.As(), reflect.GetRepeatedFieldRef(value(), &field_desc)); - case Kind::kTimestamp: + case TypeKind::kTimestamp: return context.value_factory() .CreateBorrowedListValue< ParsedProtoListValue>( owner_from_this(), field.type.As(), reflect.GetRepeatedFieldRef(value(), &field_desc)); - case Kind::kList: + case TypeKind::kList: // google.protobuf.ListValue return context.value_factory() .CreateBorrowedListValue< @@ -2363,7 +2363,7 @@ absl::StatusOr> ParsedProtoStructValue::GetRepeatedField( owner_from_this(), field.type.As(), reflect.GetRepeatedFieldRef(value(), &field_desc)); - case Kind::kMap: + case TypeKind::kMap: // google.protobuf.Struct return context.value_factory() .CreateBorrowedListValue< @@ -2371,7 +2371,7 @@ absl::StatusOr> ParsedProtoStructValue::GetRepeatedField( owner_from_this(), field.type.As(), reflect.GetRepeatedFieldRef(value(), &field_desc)); - case Kind::kDyn: + case TypeKind::kDyn: // google.protobuf.Value. return context.value_factory() .CreateBorrowedListValue< @@ -2379,14 +2379,14 @@ absl::StatusOr> ParsedProtoStructValue::GetRepeatedField( owner_from_this(), field.type.As(), reflect.GetRepeatedFieldRef(value(), &field_desc)); - case Kind::kAny: + case TypeKind::kAny: return context.value_factory() .CreateBorrowedListValue< ParsedProtoListValue>( owner_from_this(), field.type.As(), reflect.GetRepeatedFieldRef(value(), &field_desc)); - case Kind::kBool: + case TypeKind::kBool: // google.protobuf.BoolValue, mapped to CEL primitive bool type for // list elements. return context.value_factory() @@ -2395,7 +2395,7 @@ absl::StatusOr> ParsedProtoStructValue::GetRepeatedField( owner_from_this(), field.type.As(), reflect.GetRepeatedFieldRef(value(), &field_desc)); - case Kind::kBytes: + case TypeKind::kBytes: // google.protobuf.BytesValue, mapped to CEL primitive bytes type for // list elements. return context.value_factory() @@ -2404,7 +2404,7 @@ absl::StatusOr> ParsedProtoStructValue::GetRepeatedField( owner_from_this(), field.type.As(), reflect.GetRepeatedFieldRef(value(), &field_desc)); - case Kind::kDouble: + case TypeKind::kDouble: // google.protobuf.{FloatValue,DoubleValue}, mapped to CEL primitive // double type for list elements. return context.value_factory() @@ -2413,7 +2413,7 @@ absl::StatusOr> ParsedProtoStructValue::GetRepeatedField( owner_from_this(), field.type.As(), reflect.GetRepeatedFieldRef(value(), &field_desc)); - case Kind::kInt: + case TypeKind::kInt: // google.protobuf.{Int32Value,Int64Value}, mapped to CEL primitive // int type for list elements. return context.value_factory() @@ -2422,7 +2422,7 @@ absl::StatusOr> ParsedProtoStructValue::GetRepeatedField( owner_from_this(), field.type.As(), reflect.GetRepeatedFieldRef(value(), &field_desc)); - case Kind::kString: + case TypeKind::kString: // google.protobuf.StringValue, mapped to CEL primitive bytes type for // list elements. return context.value_factory() @@ -2431,7 +2431,7 @@ absl::StatusOr> ParsedProtoStructValue::GetRepeatedField( owner_from_this(), field.type.As(), reflect.GetRepeatedFieldRef(value(), &field_desc)); - case Kind::kUint: + case TypeKind::kUint: // google.protobuf.{UInt32Value,UInt64Value}, mapped to CEL primitive // uint type for list elements. return context.value_factory() @@ -2440,7 +2440,7 @@ absl::StatusOr> ParsedProtoStructValue::GetRepeatedField( owner_from_this(), field.type.As(), reflect.GetRepeatedFieldRef(value(), &field_desc)); - case Kind::kStruct: + case TypeKind::kStruct: return context.value_factory() .CreateBorrowedListValue< ParsedProtoListValue>( @@ -2458,13 +2458,13 @@ absl::StatusOr> ParsedProtoStructValue::GetRepeatedField( reflect.GetRepeatedFieldRef(value(), &field_desc)); case google::protobuf::FieldDescriptor::TYPE_ENUM: switch (field.type.As()->element()->kind()) { - case Kind::kNullType: + case TypeKind::kNullType: return context.value_factory() .CreateListValue>( field.type.As(), reflect.GetRepeatedFieldRef(value(), &field_desc) .size()); - case Kind::kEnum: + case TypeKind::kEnum: return context.value_factory() .CreateBorrowedListValue< ParsedProtoListValue>( @@ -2522,14 +2522,14 @@ absl::StatusOr> ParsedProtoStructValue::GetSingularField( ABSL_FALLTHROUGH_INTENDED; case google::protobuf::FieldDescriptor::TYPE_MESSAGE: switch (field.type->kind()) { - case Kind::kDuration: { + case TypeKind::kDuration: { CEL_ASSIGN_OR_RETURN( auto duration, protobuf_internal::AbslDurationFromDurationProto( reflect.GetMessage(value(), &field_desc, type()->factory_))); return context.value_factory().CreateUncheckedDurationValue(duration); } - case Kind::kTimestamp: { + case TypeKind::kTimestamp: { CEL_ASSIGN_OR_RETURN( auto timestamp, protobuf_internal::AbslTimeFromTimestampProto( @@ -2537,39 +2537,39 @@ absl::StatusOr> ParsedProtoStructValue::GetSingularField( return context.value_factory().CreateUncheckedTimestampValue( timestamp); } - case Kind::kList: + case TypeKind::kList: // google.protobuf.ListValue return protobuf_internal::CreateBorrowedListValue( owner_from_this(), context.value_factory(), reflect.GetMessage(value(), &field_desc)); - case Kind::kMap: + case TypeKind::kMap: // google.protobuf.Struct return protobuf_internal::CreateBorrowedStruct( owner_from_this(), context.value_factory(), reflect.GetMessage(value(), &field_desc)); - case Kind::kDyn: + case TypeKind::kDyn: // google.protobuf.Value return protobuf_internal::CreateBorrowedValue( owner_from_this(), context.value_factory(), reflect.GetMessage(value(), &field_desc)); - case Kind::kAny: + case TypeKind::kAny: // google.protobuf.Any return ProtoValue::Create(context.value_factory(), reflect.GetMessage(value(), &field_desc)); - case Kind::kWrapper: { + case TypeKind::kWrapper: { if (context.unbox_null_wrapper_types() && !reflect.HasField(value(), &field_desc)) { return context.value_factory().GetNullValue(); } switch (field.type.As()->wrapped()->kind()) { - case Kind::kBool: { + case TypeKind::kBool: { CEL_ASSIGN_OR_RETURN( auto wrapped, protobuf_internal::UnwrapBoolValueProto(reflect.GetMessage( value(), &field_desc, type()->factory_))); return context.value_factory().CreateBoolValue(wrapped); } - case Kind::kBytes: { + case TypeKind::kBytes: { CEL_ASSIGN_OR_RETURN( auto wrapped, protobuf_internal::UnwrapBytesValueProto(reflect.GetMessage( @@ -2577,21 +2577,21 @@ absl::StatusOr> ParsedProtoStructValue::GetSingularField( return context.value_factory().CreateBytesValue( std::move(wrapped)); } - case Kind::kDouble: { + case TypeKind::kDouble: { CEL_ASSIGN_OR_RETURN( auto wrapped, protobuf_internal::UnwrapDoubleValueProto(reflect.GetMessage( value(), &field_desc, type()->factory_))); return context.value_factory().CreateDoubleValue(wrapped); } - case Kind::kInt: { + case TypeKind::kInt: { CEL_ASSIGN_OR_RETURN( auto wrapped, protobuf_internal::UnwrapIntValueProto(reflect.GetMessage( value(), &field_desc, type()->factory_))); return context.value_factory().CreateIntValue(wrapped); } - case Kind::kString: { + case TypeKind::kString: { CEL_ASSIGN_OR_RETURN( auto wrapped, protobuf_internal::UnwrapStringValueProto(reflect.GetMessage( @@ -2599,7 +2599,7 @@ absl::StatusOr> ParsedProtoStructValue::GetSingularField( return context.value_factory().CreateUncheckedStringValue( std::move(wrapped)); } - case Kind::kUint: { + case TypeKind::kUint: { CEL_ASSIGN_OR_RETURN( auto wrapped, protobuf_internal::UnwrapUIntValueProto(reflect.GetMessage( @@ -2611,7 +2611,7 @@ absl::StatusOr> ParsedProtoStructValue::GetSingularField( ABSL_UNREACHABLE(); } } - case Kind::kStruct: + case TypeKind::kStruct: return context.value_factory() .CreateBorrowedStructValue( owner_from_this(), field.type.As(), @@ -2625,9 +2625,9 @@ absl::StatusOr> ParsedProtoStructValue::GetSingularField( &field_desc); case google::protobuf::FieldDescriptor::TYPE_ENUM: switch (field.type->kind()) { - case Kind::kNullType: + case TypeKind::kNullType: return context.value_factory().GetNullValue(); - case Kind::kEnum: + case TypeKind::kEnum: return context.value_factory().CreateEnumValue( field.type.As(), reflect.GetEnumValue(value(), &field_desc)); diff --git a/extensions/protobuf/value.cc b/extensions/protobuf/value.cc index daf2c90f3..3dd6bf953 100644 --- a/extensions/protobuf/value.cc +++ b/extensions/protobuf/value.cc @@ -1507,10 +1507,10 @@ absl::StatusOr> ProtoValue::Create( CEL_ASSIGN_OR_RETURN( auto type, ProtoType::Resolve(value_factory.type_manager(), descriptor)); switch (type->kind()) { - case Kind::kNullType: + case TypeKind::kNullType: // google.protobuf.NullValue is an enum, which represents JSON null. return value_factory.GetNullValue(); - case Kind::kEnum: + case TypeKind::kEnum: return value_factory.CreateEnumValue(std::move(type).As(), value); default: @@ -1548,14 +1548,14 @@ absl::StatusOr> ProtoValue::Create(ValueFactory& value_factory, absl::NotFoundError(absl::StrCat("type not found: ", type_url))); } switch ((*type)->kind()) { - case Kind::kAny: + case TypeKind::kAny: ABSL_DCHECK(type_name == "google.protobuf.Any") << type_name; // google.protobuf.Any // // We refuse google.protobuf.Any wrapped in google.protobuf.Any. return absl::InvalidArgumentError( "refusing to unpack google.protobuf.Any to google.protobuf.Any"); - case Kind::kStruct: { + case TypeKind::kStruct: { if (!ProtoStructType::Is(**type)) { return absl::FailedPreconditionError( "google.protobuf.Any can only be unpacked to protocol " @@ -1577,15 +1577,15 @@ absl::StatusOr> ProtoValue::Create(ValueFactory& value_factory, } return ProtoStructValue::Create(value_factory, std::move(*proto)); } - case Kind::kWrapper: { + case TypeKind::kWrapper: { switch ((*type)->As().wrapped()->kind()) { - case Kind::kBool: { + case TypeKind::kBool: { // google.protobuf.BoolValue CEL_ASSIGN_OR_RETURN(auto proto, UnpackTo(payload)); return Create(value_factory, proto); } - case Kind::kInt: { + case TypeKind::kInt: { // google.protobuf.{Int32Value,Int64Value} if (type_name == "google.protobuf.Int32Value") { CEL_ASSIGN_OR_RETURN( @@ -1598,7 +1598,7 @@ absl::StatusOr> ProtoValue::Create(ValueFactory& value_factory, return Create(value_factory, std::move(proto)); } } break; - case Kind::kUint: { + case TypeKind::kUint: { // google.protobuf.{UInt32Value,UInt64Value} if (type_name == "google.protobuf.UInt32Value") { CEL_ASSIGN_OR_RETURN( @@ -1611,7 +1611,7 @@ absl::StatusOr> ProtoValue::Create(ValueFactory& value_factory, return Create(value_factory, std::move(proto)); } } break; - case Kind::kDouble: { + case TypeKind::kDouble: { // google.protobuf.{FloatValue,DoubleValue} if (type_name == "google.protobuf.FloatValue") { CEL_ASSIGN_OR_RETURN( @@ -1624,13 +1624,13 @@ absl::StatusOr> ProtoValue::Create(ValueFactory& value_factory, return Create(value_factory, std::move(proto)); } } break; - case Kind::kBytes: { + case TypeKind::kBytes: { // google.protobuf.BytesValue CEL_ASSIGN_OR_RETURN(auto proto, UnpackTo(payload)); return Create(value_factory, std::move(proto)); } - case Kind::kString: { + case TypeKind::kString: { // google.protobuf.StringValue CEL_ASSIGN_OR_RETURN( auto proto, UnpackTo(payload)); @@ -1640,35 +1640,35 @@ absl::StatusOr> ProtoValue::Create(ValueFactory& value_factory, ABSL_UNREACHABLE(); } } break; - case Kind::kList: { + case TypeKind::kList: { // google.protobuf.ListValue ABSL_DCHECK(type_name == "google.protobuf.ListValue") << type_name; CEL_ASSIGN_OR_RETURN(auto proto, UnpackTo(payload)); return Create(value_factory, std::move(proto)); } - case Kind::kMap: { + case TypeKind::kMap: { // google.protobuf.Struct ABSL_DCHECK(type_name == "google.protobuf.Struct") << type_name; CEL_ASSIGN_OR_RETURN(auto proto, UnpackTo(payload)); return Create(value_factory, std::move(proto)); } - case Kind::kDyn: { + case TypeKind::kDyn: { // google.protobuf.Value ABSL_DCHECK(type_name == "google.protobuf.Value") << type_name; CEL_ASSIGN_OR_RETURN(auto proto, UnpackTo(payload)); return Create(value_factory, std::move(proto)); } - case Kind::kDuration: { + case TypeKind::kDuration: { // google.protobuf.Duration ABSL_DCHECK(type_name == "google.protobuf.Duration") << type_name; CEL_ASSIGN_OR_RETURN(auto proto, UnpackTo(payload)); return Create(value_factory, proto); } - case Kind::kTimestamp: { + case TypeKind::kTimestamp: { // google.protobuf.Timestamp ABSL_DCHECK(type_name == "google.protobuf.Timestamp") << type_name; CEL_ASSIGN_OR_RETURN(auto proto, From ebf8c934b933e26b5a9752b4f17cc897c72a34e6 Mon Sep 17 00:00:00 2001 From: Justin King Date: Tue, 23 May 2023 14:02:03 -0700 Subject: [PATCH 302/303] Migrate `Value::kind()` to `ValueKind` PiperOrigin-RevId: 534550418 --- base/internal/BUILD | 2 +- base/internal/data.h | 2 + base/internal/function_adapter.h | 4 +- base/kind.h | 16 +++ base/value.cc | 167 ++++++++++++++-------------- base/value.h | 16 +-- base/value_test.cc | 90 +++++++-------- base/values/bytes_value.h | 6 +- base/values/enum_value.h | 4 +- base/values/error_value.h | 4 +- base/values/list_value.h | 6 +- base/values/map_value.h | 4 +- base/values/map_value_builder.h | 16 +-- base/values/opaque_value.h | 6 +- base/values/string_value.h | 4 +- base/values/struct_value.h | 4 +- base/values/type_value.h | 4 +- base/values/unknown_value.h | 4 +- eval/compiler/BUILD | 2 +- eval/compiler/constant_folding.cc | 4 +- eval/eval/BUILD | 7 +- eval/eval/comprehension_step.cc | 6 +- eval/eval/container_access_step.cc | 40 +++---- eval/eval/function_step.cc | 13 ++- eval/eval/select_step.cc | 10 +- eval/internal/interop.cc | 42 ++++--- extensions/protobuf/struct_value.cc | 8 +- 27 files changed, 256 insertions(+), 235 deletions(-) diff --git a/base/internal/BUILD b/base/internal/BUILD index e2d356430..f53dbac66 100644 --- a/base/internal/BUILD +++ b/base/internal/BUILD @@ -152,9 +152,9 @@ cc_library( "function_adapter.h", ], deps = [ + "//base:data", "//base:handle", "//base:kind", - "//base:value", "//internal:status_macros", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", diff --git a/base/internal/data.h b/base/internal/data.h index d796e0717..10ba98478 100644 --- a/base/internal/data.h +++ b/base/internal/data.h @@ -186,6 +186,8 @@ class HeapData /* : public Data */ { explicit HeapData(TypeKind kind) : HeapData(TypeKindToKind(kind)) {} + explicit HeapData(ValueKind kind) : HeapData(ValueKindToKind(kind)) {} + private: // Called by Arena-based memory managers to determine whether we actually need // our destructor called. Subclasses should override this if they want their diff --git a/base/internal/function_adapter.h b/base/internal/function_adapter.h index fc0624109..e149c8408 100644 --- a/base/internal/function_adapter.h +++ b/base/internal/function_adapter.h @@ -184,7 +184,7 @@ struct HandleToAdaptedVisitor { absl::Status operator()(const Handle** out) { if (!input->Is()) { return absl::InvalidArgumentError( - absl::StrCat("expected ", KindToString(T::kKind), " value")); + absl::StrCat("expected ", ValueKindToString(T::kKind), " value")); } *out = &(input.As()); return absl::OkStatus(); @@ -194,7 +194,7 @@ struct HandleToAdaptedVisitor { absl::Status operator()(const T** out) { if (!input->Is()) { return absl::InvalidArgumentError( - absl::StrCat("expected ", KindToString(T::kKind), " value")); + absl::StrCat("expected ", ValueKindToString(T::kKind), " value")); } *out = &(*input.As()); return absl::OkStatus(); diff --git a/base/kind.h b/base/kind.h index 085a7ee32..6f78d8596 100644 --- a/base/kind.h +++ b/base/kind.h @@ -225,6 +225,22 @@ constexpr ValueKind KindToValueKind(Kind kind) { static_cast>(kind)); } +constexpr ValueKind TypeKindToValueKind(TypeKind kind) { + ABSL_ASSERT(KindIsValueKind(TypeKindToKind(kind))); + return static_cast( + static_cast>(kind)); +} + +constexpr TypeKind ValueKindToTypeKind(ValueKind kind) { + ABSL_ASSERT(KindIsTypeKind(ValueKindToKind(kind))); + return static_cast( + static_cast>(kind)); +} + +static_assert(std::is_same_v, + std::underlying_type_t>, + "TypeKind and ValueKind must have the same underlying type"); + } // namespace cel #endif // THIRD_PARTY_CEL_CPP_BASE_KIND_H_ diff --git a/base/value.cc b/base/value.cc index 571c11606..e26b5f5f2 100644 --- a/base/value.cc +++ b/base/value.cc @@ -46,39 +46,39 @@ CEL_INTERNAL_VALUE_IMPL(Value); Handle Value::type() const { switch (kind()) { - case Kind::kNullType: + case ValueKind::kNullType: return static_cast(this)->type().As(); - case Kind::kError: + case ValueKind::kError: return static_cast(this)->type().As(); - case Kind::kType: + case ValueKind::kType: return static_cast(this)->type().As(); - case Kind::kBool: + case ValueKind::kBool: return static_cast(this)->type().As(); - case Kind::kInt: + case ValueKind::kInt: return static_cast(this)->type().As(); - case Kind::kUint: + case ValueKind::kUint: return static_cast(this)->type().As(); - case Kind::kDouble: + case ValueKind::kDouble: return static_cast(this)->type().As(); - case Kind::kString: + case ValueKind::kString: return static_cast(this)->type().As(); - case Kind::kBytes: + case ValueKind::kBytes: return static_cast(this)->type().As(); - case Kind::kEnum: + case ValueKind::kEnum: return static_cast(this)->type().As(); - case Kind::kDuration: + case ValueKind::kDuration: return static_cast(this)->type().As(); - case Kind::kTimestamp: + case ValueKind::kTimestamp: return static_cast(this)->type().As(); - case Kind::kList: + case ValueKind::kList: return static_cast(this)->type().As(); - case Kind::kMap: + case ValueKind::kMap: return static_cast(this)->type().As(); - case Kind::kStruct: + case ValueKind::kStruct: return static_cast(this)->type().As(); - case Kind::kUnknown: + case ValueKind::kUnknown: return static_cast(this)->type().As(); - case Kind::kOpaque: + case ValueKind::kOpaque: return static_cast(this)->type().As(); default: ABSL_UNREACHABLE(); @@ -87,39 +87,39 @@ Handle Value::type() const { std::string Value::DebugString() const { switch (kind()) { - case Kind::kNullType: + case ValueKind::kNullType: return static_cast(this)->DebugString(); - case Kind::kError: + case ValueKind::kError: return static_cast(this)->DebugString(); - case Kind::kType: + case ValueKind::kType: return static_cast(this)->DebugString(); - case Kind::kBool: + case ValueKind::kBool: return static_cast(this)->DebugString(); - case Kind::kInt: + case ValueKind::kInt: return static_cast(this)->DebugString(); - case Kind::kUint: + case ValueKind::kUint: return static_cast(this)->DebugString(); - case Kind::kDouble: + case ValueKind::kDouble: return static_cast(this)->DebugString(); - case Kind::kString: + case ValueKind::kString: return static_cast(this)->DebugString(); - case Kind::kBytes: + case ValueKind::kBytes: return static_cast(this)->DebugString(); - case Kind::kEnum: + case ValueKind::kEnum: return static_cast(this)->DebugString(); - case Kind::kDuration: + case ValueKind::kDuration: return static_cast(this)->DebugString(); - case Kind::kTimestamp: + case ValueKind::kTimestamp: return static_cast(this)->DebugString(); - case Kind::kList: + case ValueKind::kList: return static_cast(this)->DebugString(); - case Kind::kMap: + case ValueKind::kMap: return static_cast(this)->DebugString(); - case Kind::kStruct: + case ValueKind::kStruct: return static_cast(this)->DebugString(); - case Kind::kUnknown: + case ValueKind::kUnknown: return static_cast(this)->DebugString(); - case Kind::kOpaque: + case ValueKind::kOpaque: return static_cast(this)->DebugString(); default: ABSL_UNREACHABLE(); @@ -128,46 +128,46 @@ std::string Value::DebugString() const { namespace base_internal { -bool ValueHandle::Equals(const Value& lhs, const Value& rhs, Kind kind) { +bool ValueHandle::Equals(const Value& lhs, const Value& rhs, ValueKind kind) { switch (kind) { - case Kind::kNullType: + case ValueKind::kNullType: return true; - case Kind::kError: + case ValueKind::kError: return static_cast(lhs).value() == static_cast(rhs).value(); - case Kind::kType: + case ValueKind::kType: return static_cast(lhs).Equals( static_cast(rhs)); - case Kind::kBool: + case ValueKind::kBool: return static_cast(lhs).value() == static_cast(rhs).value(); - case Kind::kInt: + case ValueKind::kInt: return static_cast(lhs).value() == static_cast(rhs).value(); - case Kind::kUint: + case ValueKind::kUint: return static_cast(lhs).value() == static_cast(rhs).value(); - case Kind::kDouble: + case ValueKind::kDouble: return static_cast(lhs).value() == static_cast(rhs).value(); - case Kind::kString: + case ValueKind::kString: return static_cast(lhs).Equals( static_cast(rhs)); - case Kind::kBytes: + case ValueKind::kBytes: return static_cast(lhs).Equals( static_cast(rhs)); - case Kind::kEnum: + case ValueKind::kEnum: return static_cast(lhs).number() == static_cast(rhs).number() && static_cast(lhs).type() == static_cast(rhs).type(); - case Kind::kDuration: + case ValueKind::kDuration: return static_cast(lhs).value() == static_cast(rhs).value(); - case Kind::kTimestamp: + case ValueKind::kTimestamp: return static_cast(lhs).value() == static_cast(rhs).value(); - case Kind::kList: { + case ValueKind::kList: { bool stored_inline = Metadata::IsStoredInline(lhs); if (stored_inline != Metadata::IsStoredInline(rhs)) { return false; @@ -178,7 +178,7 @@ bool ValueHandle::Equals(const Value& lhs, const Value& rhs, Kind kind) { } return &lhs == &rhs; } - case Kind::kMap: { + case ValueKind::kMap: { bool stored_inline = Metadata::IsStoredInline(lhs); if (stored_inline != Metadata::IsStoredInline(rhs)) { return false; @@ -189,7 +189,7 @@ bool ValueHandle::Equals(const Value& lhs, const Value& rhs, Kind kind) { } return &lhs == &rhs; } - case Kind::kStruct: { + case ValueKind::kStruct: { bool stored_inline = Metadata::IsStoredInline(lhs); if (stored_inline != Metadata::IsStoredInline(rhs)) { return false; @@ -202,12 +202,12 @@ bool ValueHandle::Equals(const Value& lhs, const Value& rhs, Kind kind) { } return &lhs == &rhs; } - case Kind::kUnknown: + case ValueKind::kUnknown: return static_cast(lhs).attribute_set() == static_cast(rhs).attribute_set() && static_cast(lhs).function_result_set() == static_cast(rhs).function_result_set(); - case Kind::kOpaque: + case ValueKind::kOpaque: return &lhs == &rhs; default: ABSL_UNREACHABLE(); @@ -223,7 +223,7 @@ bool ValueHandle::Equals(const ValueHandle& other) const { if (self == nullptr || that == nullptr) { return false; } - Kind kind = self->kind(); + ValueKind kind = self->kind(); return kind == that->kind() && Equals(*self, *that, kind); } @@ -232,16 +232,16 @@ void ValueHandle::CopyFrom(const ValueHandle& other) { auto locality = other.data_.locality(); if (locality == DataLocality::kStoredInline) { if (!other.data_.IsTrivial()) { - switch (other.data_.kind_inline()) { - case Kind::kError: + switch (KindToValueKind(other.data_.kind_inline())) { + case ValueKind::kError: data_.ConstructInline( *static_cast(other.data_.get_inline())); return; - case Kind::kUnknown: + case ValueKind::kUnknown: data_.ConstructInline( *static_cast(other.data_.get_inline())); return; - case Kind::kString: + case ValueKind::kString: switch (other.data_.inline_variant()) { case InlinedStringValueVariant::kCord: data_.ConstructInline( @@ -255,7 +255,7 @@ void ValueHandle::CopyFrom(const ValueHandle& other) { break; } return; - case Kind::kBytes: + case ValueKind::kBytes: switch (other.data_.inline_variant()) { case InlinedBytesValueVariant::kCord: data_.ConstructInline( @@ -269,12 +269,12 @@ void ValueHandle::CopyFrom(const ValueHandle& other) { break; } return; - case Kind::kType: + case ValueKind::kType: data_.ConstructInline( *static_cast( other.data_.get_inline())); return; - case Kind::kEnum: + case ValueKind::kEnum: data_.ConstructInline( *static_cast(other.data_.get_inline())); return; @@ -297,18 +297,18 @@ void ValueHandle::MoveFrom(ValueHandle& other) { // data_ is currently uninitialized. if (other.data_.IsStoredInline()) { if (!other.data_.IsTrivial()) { - switch (other.data_.kind_inline()) { - case Kind::kError: + switch (KindToValueKind(other.data_.kind_inline())) { + case ValueKind::kError: data_.ConstructInline( std::move(*static_cast(other.data_.get_inline()))); other.data_.Destruct(); break; - case Kind::kUnknown: + case ValueKind::kUnknown: data_.ConstructInline( std::move(*static_cast(other.data_.get_inline()))); other.data_.Destruct(); break; - case Kind::kString: + case ValueKind::kString: switch (other.data_.inline_variant()) { case InlinedStringValueVariant::kCord: data_.ConstructInline( @@ -324,7 +324,7 @@ void ValueHandle::MoveFrom(ValueHandle& other) { break; } break; - case Kind::kBytes: + case ValueKind::kBytes: switch (other.data_.inline_variant()) { case InlinedBytesValueVariant::kCord: data_.ConstructInline( @@ -340,12 +340,12 @@ void ValueHandle::MoveFrom(ValueHandle& other) { break; } break; - case Kind::kType: + case ValueKind::kType: data_.ConstructInline(std::move( *static_cast(other.data_.get_inline()))); other.data_.Destruct(); break; - case Kind::kEnum: + case ValueKind::kEnum: data_.ConstructInline(std::move( *static_cast(other.data_.get_inline()))); other.data_.Destruct(); @@ -381,14 +381,14 @@ void ValueHandle::Destruct() { return; case DataLocality::kStoredInline: if (!data_.IsTrivial()) { - switch (data_.kind_inline()) { - case Kind::kError: + switch (KindToValueKind(data_.kind_inline())) { + case ValueKind::kError: data_.Destruct(); return; - case Kind::kUnknown: + case ValueKind::kUnknown: data_.Destruct(); return; - case Kind::kString: + case ValueKind::kString: switch (data_.inline_variant()) { case InlinedStringValueVariant::kCord: data_.Destruct(); @@ -398,7 +398,7 @@ void ValueHandle::Destruct() { break; } return; - case Kind::kBytes: + case ValueKind::kBytes: switch (data_.inline_variant()) { case InlinedBytesValueVariant::kCord: data_.Destruct(); @@ -408,10 +408,10 @@ void ValueHandle::Destruct() { break; } return; - case Kind::kType: + case ValueKind::kType: data_.Destruct(); return; - case Kind::kEnum: + case ValueKind::kEnum: data_.Destruct(); return; default: @@ -428,27 +428,28 @@ void ValueHandle::Destruct() { } void ValueHandle::Delete() const { - Delete(data_.kind_heap(), *static_cast(data_.get_heap())); + Delete(KindToValueKind(data_.kind_heap()), + *static_cast(data_.get_heap())); } -void ValueHandle::Delete(Kind kind, const Value& value) { +void ValueHandle::Delete(ValueKind kind, const Value& value) { switch (kind) { - case Kind::kList: + case ValueKind::kList: delete static_cast(&value); return; - case Kind::kMap: + case ValueKind::kMap: delete static_cast(&value); return; - case Kind::kStruct: + case ValueKind::kStruct: delete static_cast(&value); return; - case Kind::kString: + case ValueKind::kString: delete static_cast(&value); return; - case Kind::kBytes: + case ValueKind::kBytes: delete static_cast(&value); return; - case Kind::kOpaque: + case ValueKind::kOpaque: delete static_cast(&value); return; default: @@ -458,7 +459,7 @@ void ValueHandle::Delete(Kind kind, const Value& value) { void ValueMetadata::Unref(const Value& value) { if (Metadata::Unref(value)) { - ValueHandle::Delete(Metadata::KindHeap(value), value); + ValueHandle::Delete(KindToValueKind(Metadata::KindHeap(value)), value); } } diff --git a/base/value.h b/base/value.h index 6ff9fa0ba..d83c52c31 100644 --- a/base/value.h +++ b/base/value.h @@ -55,7 +55,9 @@ class Value : public base_internal::Data { // Returns the kind of the value. This is equivalent to `type().kind()` but // faster in many scenarios. As such it should be preferred when only the kind // is required. - Kind kind() const { return base_internal::Metadata::Kind(*this); } + ValueKind kind() const { + return KindToValueKind(base_internal::Metadata::Kind(*this)); + } // Returns the type of the value. If you only need the kind, prefer `kind()`. Handle type() const; @@ -181,7 +183,7 @@ class ValueHandle final { private: friend class ValueMetadata; - static bool Equals(const Value& lhs, const Value& rhs, Kind kind); + static bool Equals(const Value& lhs, const Value& rhs, ValueKind kind); void CopyFrom(const ValueHandle& other); @@ -203,7 +205,7 @@ class ValueHandle final { void Delete() const; - static void Delete(Kind kind, const Value& value); + static void Delete(ValueKind kind, const Value& value); AnyValue data_; }; @@ -231,7 +233,7 @@ struct HandleTraits && template class SimpleValue : public Value, InlineData { public: - static constexpr Kind kKind = static_cast(T::kKind); + static constexpr ValueKind kKind = TypeKindToValueKind(T::kKind); static bool Is(const Value& value) { return value.kind() == kKind; } @@ -242,7 +244,7 @@ class SimpleValue : public Value, InlineData { SimpleValue& operator=(const SimpleValue&) = default; SimpleValue& operator=(SimpleValue&&) = default; - constexpr Kind kind() const { return kKind; } + constexpr ValueKind kind() const { return kKind; } const Handle& type() const { return T::Get(); } @@ -267,7 +269,7 @@ class SimpleValue : public Value, InlineData { template <> class SimpleValue : public Value, InlineData { public: - static constexpr Kind kKind = Kind::kNullType; + static constexpr ValueKind kKind = ValueKind::kNullType; static bool Is(const Value& value) { return value.kind() == kKind; } @@ -278,7 +280,7 @@ class SimpleValue : public Value, InlineData { SimpleValue& operator=(const SimpleValue&) = default; SimpleValue& operator=(SimpleValue&&) = default; - constexpr Kind kind() const { return kKind; } + constexpr ValueKind kind() const { return kKind; } const Handle& type() const { return NullType::Get(); } diff --git a/base/value_test.cc b/base/value_test.cc index 5a8874431..69602c662 100644 --- a/base/value_test.cc +++ b/base/value_test.cc @@ -753,7 +753,7 @@ TEST_P(ValueTest, Bool) { EXPECT_FALSE(false_value->Is()); EXPECT_EQ(false_value, false_value); EXPECT_EQ(false_value, value_factory.CreateBoolValue(false)); - EXPECT_EQ(false_value->kind(), Kind::kBool); + EXPECT_EQ(false_value->kind(), ValueKind::kBool); EXPECT_EQ(false_value->type(), type_factory.GetBoolType()); EXPECT_FALSE(false_value->value()); @@ -762,7 +762,7 @@ TEST_P(ValueTest, Bool) { EXPECT_FALSE(true_value->Is()); EXPECT_EQ(true_value, true_value); EXPECT_EQ(true_value, value_factory.CreateBoolValue(true)); - EXPECT_EQ(true_value->kind(), Kind::kBool); + EXPECT_EQ(true_value->kind(), ValueKind::kBool); EXPECT_EQ(true_value->type(), type_factory.GetBoolType()); EXPECT_TRUE(true_value->value()); @@ -779,7 +779,7 @@ TEST_P(ValueTest, Int) { EXPECT_FALSE(zero_value->Is()); EXPECT_EQ(zero_value, zero_value); EXPECT_EQ(zero_value, value_factory.CreateIntValue(0)); - EXPECT_EQ(zero_value->kind(), Kind::kInt); + EXPECT_EQ(zero_value->kind(), ValueKind::kInt); EXPECT_EQ(zero_value->type(), type_factory.GetIntType()); EXPECT_EQ(zero_value->value(), 0); @@ -788,7 +788,7 @@ TEST_P(ValueTest, Int) { EXPECT_FALSE(one_value->Is()); EXPECT_EQ(one_value, one_value); EXPECT_EQ(one_value, value_factory.CreateIntValue(1)); - EXPECT_EQ(one_value->kind(), Kind::kInt); + EXPECT_EQ(one_value->kind(), ValueKind::kInt); EXPECT_EQ(one_value->type(), type_factory.GetIntType()); EXPECT_EQ(one_value->value(), 1); @@ -805,7 +805,7 @@ TEST_P(ValueTest, Uint) { EXPECT_FALSE(zero_value->Is()); EXPECT_EQ(zero_value, zero_value); EXPECT_EQ(zero_value, value_factory.CreateUintValue(0)); - EXPECT_EQ(zero_value->kind(), Kind::kUint); + EXPECT_EQ(zero_value->kind(), ValueKind::kUint); EXPECT_EQ(zero_value->type(), type_factory.GetUintType()); EXPECT_EQ(zero_value->value(), 0); @@ -814,7 +814,7 @@ TEST_P(ValueTest, Uint) { EXPECT_FALSE(one_value->Is()); EXPECT_EQ(one_value, one_value); EXPECT_EQ(one_value, value_factory.CreateUintValue(1)); - EXPECT_EQ(one_value->kind(), Kind::kUint); + EXPECT_EQ(one_value->kind(), ValueKind::kUint); EXPECT_EQ(one_value->type(), type_factory.GetUintType()); EXPECT_EQ(one_value->value(), 1); @@ -831,7 +831,7 @@ TEST_P(ValueTest, Double) { EXPECT_FALSE(zero_value->Is()); EXPECT_EQ(zero_value, zero_value); EXPECT_EQ(zero_value, value_factory.CreateDoubleValue(0.0)); - EXPECT_EQ(zero_value->kind(), Kind::kDouble); + EXPECT_EQ(zero_value->kind(), ValueKind::kDouble); EXPECT_EQ(zero_value->type(), type_factory.GetDoubleType()); EXPECT_EQ(zero_value->value(), 0.0); @@ -840,7 +840,7 @@ TEST_P(ValueTest, Double) { EXPECT_FALSE(one_value->Is()); EXPECT_EQ(one_value, one_value); EXPECT_EQ(one_value, value_factory.CreateDoubleValue(1.0)); - EXPECT_EQ(one_value->kind(), Kind::kDouble); + EXPECT_EQ(one_value->kind(), ValueKind::kDouble); EXPECT_EQ(one_value->type(), type_factory.GetDoubleType()); EXPECT_EQ(one_value->value(), 1.0); @@ -859,7 +859,7 @@ TEST_P(ValueTest, Duration) { EXPECT_EQ(zero_value, zero_value); EXPECT_EQ(zero_value, Must(value_factory.CreateDurationValue(absl::ZeroDuration()))); - EXPECT_EQ(zero_value->kind(), Kind::kDuration); + EXPECT_EQ(zero_value->kind(), ValueKind::kDuration); EXPECT_EQ(zero_value->type(), type_factory.GetDurationType()); EXPECT_EQ(zero_value->value(), absl::ZeroDuration()); @@ -868,7 +868,7 @@ TEST_P(ValueTest, Duration) { EXPECT_TRUE(one_value->Is()); EXPECT_FALSE(one_value->Is()); EXPECT_EQ(one_value, one_value); - EXPECT_EQ(one_value->kind(), Kind::kDuration); + EXPECT_EQ(one_value->kind(), ValueKind::kDuration); EXPECT_EQ(one_value->type(), type_factory.GetDurationType()); EXPECT_EQ(one_value->value(), absl::ZeroDuration() + absl::Nanoseconds(1)); @@ -889,7 +889,7 @@ TEST_P(ValueTest, Timestamp) { EXPECT_EQ(zero_value, zero_value); EXPECT_EQ(zero_value, Must(value_factory.CreateTimestampValue(absl::UnixEpoch()))); - EXPECT_EQ(zero_value->kind(), Kind::kTimestamp); + EXPECT_EQ(zero_value->kind(), ValueKind::kTimestamp); EXPECT_EQ(zero_value->type(), type_factory.GetTimestampType()); EXPECT_EQ(zero_value->value(), absl::UnixEpoch()); @@ -898,7 +898,7 @@ TEST_P(ValueTest, Timestamp) { EXPECT_TRUE(one_value->Is()); EXPECT_FALSE(one_value->Is()); EXPECT_EQ(one_value, one_value); - EXPECT_EQ(one_value->kind(), Kind::kTimestamp); + EXPECT_EQ(one_value->kind(), ValueKind::kTimestamp); EXPECT_EQ(one_value->type(), type_factory.GetTimestampType()); EXPECT_EQ(one_value->value(), absl::UnixEpoch() + absl::Nanoseconds(1)); @@ -918,7 +918,7 @@ TEST_P(ValueTest, BytesFromString) { EXPECT_FALSE(zero_value->Is()); EXPECT_EQ(zero_value, zero_value); EXPECT_EQ(zero_value, Must(value_factory.CreateBytesValue(std::string("0")))); - EXPECT_EQ(zero_value->kind(), Kind::kBytes); + EXPECT_EQ(zero_value->kind(), ValueKind::kBytes); EXPECT_EQ(zero_value->type(), type_factory.GetBytesType()); EXPECT_EQ(zero_value->ToString(), "0"); @@ -927,7 +927,7 @@ TEST_P(ValueTest, BytesFromString) { EXPECT_FALSE(one_value->Is()); EXPECT_EQ(one_value, one_value); EXPECT_EQ(one_value, Must(value_factory.CreateBytesValue(std::string("1")))); - EXPECT_EQ(one_value->kind(), Kind::kBytes); + EXPECT_EQ(one_value->kind(), ValueKind::kBytes); EXPECT_EQ(one_value->type(), type_factory.GetBytesType()); EXPECT_EQ(one_value->ToString(), "1"); @@ -946,7 +946,7 @@ TEST_P(ValueTest, BytesFromStringView) { EXPECT_EQ(zero_value, zero_value); EXPECT_EQ(zero_value, Must(value_factory.CreateBytesValue(absl::string_view("0")))); - EXPECT_EQ(zero_value->kind(), Kind::kBytes); + EXPECT_EQ(zero_value->kind(), ValueKind::kBytes); EXPECT_EQ(zero_value->type(), type_factory.GetBytesType()); EXPECT_EQ(zero_value->ToString(), "0"); @@ -956,7 +956,7 @@ TEST_P(ValueTest, BytesFromStringView) { EXPECT_EQ(one_value, one_value); EXPECT_EQ(one_value, Must(value_factory.CreateBytesValue(absl::string_view("1")))); - EXPECT_EQ(one_value->kind(), Kind::kBytes); + EXPECT_EQ(one_value->kind(), ValueKind::kBytes); EXPECT_EQ(one_value->type(), type_factory.GetBytesType()); EXPECT_EQ(one_value->ToString(), "1"); @@ -973,7 +973,7 @@ TEST_P(ValueTest, BytesFromCord) { EXPECT_FALSE(zero_value->Is()); EXPECT_EQ(zero_value, zero_value); EXPECT_EQ(zero_value, Must(value_factory.CreateBytesValue(absl::Cord("0")))); - EXPECT_EQ(zero_value->kind(), Kind::kBytes); + EXPECT_EQ(zero_value->kind(), ValueKind::kBytes); EXPECT_EQ(zero_value->type(), type_factory.GetBytesType()); EXPECT_EQ(zero_value->ToCord(), "0"); @@ -982,7 +982,7 @@ TEST_P(ValueTest, BytesFromCord) { EXPECT_FALSE(one_value->Is()); EXPECT_EQ(one_value, one_value); EXPECT_EQ(one_value, Must(value_factory.CreateBytesValue(absl::Cord("1")))); - EXPECT_EQ(one_value->kind(), Kind::kBytes); + EXPECT_EQ(one_value->kind(), ValueKind::kBytes); EXPECT_EQ(one_value->type(), type_factory.GetBytesType()); EXPECT_EQ(one_value->ToCord(), "1"); @@ -999,7 +999,7 @@ TEST_P(ValueTest, BytesFromLiteral) { EXPECT_FALSE(zero_value->Is()); EXPECT_EQ(zero_value, zero_value); EXPECT_EQ(zero_value, Must(value_factory.CreateBytesValue("0"))); - EXPECT_EQ(zero_value->kind(), Kind::kBytes); + EXPECT_EQ(zero_value->kind(), ValueKind::kBytes); EXPECT_EQ(zero_value->type(), type_factory.GetBytesType()); EXPECT_EQ(zero_value->ToString(), "0"); @@ -1008,7 +1008,7 @@ TEST_P(ValueTest, BytesFromLiteral) { EXPECT_FALSE(one_value->Is()); EXPECT_EQ(one_value, one_value); EXPECT_EQ(one_value, Must(value_factory.CreateBytesValue("1"))); - EXPECT_EQ(one_value->kind(), Kind::kBytes); + EXPECT_EQ(one_value->kind(), ValueKind::kBytes); EXPECT_EQ(one_value->type(), type_factory.GetBytesType()); EXPECT_EQ(one_value->ToString(), "1"); @@ -1025,7 +1025,7 @@ TEST_P(ValueTest, BytesFromExternal) { EXPECT_FALSE(zero_value->Is()); EXPECT_EQ(zero_value, zero_value); EXPECT_EQ(zero_value, Must(value_factory.CreateBytesValue("0", []() {}))); - EXPECT_EQ(zero_value->kind(), Kind::kBytes); + EXPECT_EQ(zero_value->kind(), ValueKind::kBytes); EXPECT_EQ(zero_value->type(), type_factory.GetBytesType()); EXPECT_EQ(zero_value->ToString(), "0"); @@ -1034,7 +1034,7 @@ TEST_P(ValueTest, BytesFromExternal) { EXPECT_FALSE(one_value->Is()); EXPECT_EQ(one_value, one_value); EXPECT_EQ(one_value, Must(value_factory.CreateBytesValue("1", []() {}))); - EXPECT_EQ(one_value->kind(), Kind::kBytes); + EXPECT_EQ(one_value->kind(), ValueKind::kBytes); EXPECT_EQ(one_value->type(), type_factory.GetBytesType()); EXPECT_EQ(one_value->ToString(), "1"); @@ -1052,7 +1052,7 @@ TEST_P(ValueTest, StringFromString) { EXPECT_EQ(zero_value, zero_value); EXPECT_EQ(zero_value, Must(value_factory.CreateStringValue(std::string("0")))); - EXPECT_EQ(zero_value->kind(), Kind::kString); + EXPECT_EQ(zero_value->kind(), ValueKind::kString); EXPECT_EQ(zero_value->type(), type_factory.GetStringType()); EXPECT_EQ(zero_value->ToString(), "0"); @@ -1061,7 +1061,7 @@ TEST_P(ValueTest, StringFromString) { EXPECT_FALSE(one_value->Is()); EXPECT_EQ(one_value, one_value); EXPECT_EQ(one_value, Must(value_factory.CreateStringValue(std::string("1")))); - EXPECT_EQ(one_value->kind(), Kind::kString); + EXPECT_EQ(one_value->kind(), ValueKind::kString); EXPECT_EQ(one_value->type(), type_factory.GetStringType()); EXPECT_EQ(one_value->ToString(), "1"); @@ -1080,7 +1080,7 @@ TEST_P(ValueTest, StringFromStringView) { EXPECT_EQ(zero_value, zero_value); EXPECT_EQ(zero_value, Must(value_factory.CreateStringValue(absl::string_view("0")))); - EXPECT_EQ(zero_value->kind(), Kind::kString); + EXPECT_EQ(zero_value->kind(), ValueKind::kString); EXPECT_EQ(zero_value->type(), type_factory.GetStringType()); EXPECT_EQ(zero_value->ToString(), "0"); @@ -1091,7 +1091,7 @@ TEST_P(ValueTest, StringFromStringView) { EXPECT_EQ(one_value, one_value); EXPECT_EQ(one_value, Must(value_factory.CreateStringValue(absl::string_view("1")))); - EXPECT_EQ(one_value->kind(), Kind::kString); + EXPECT_EQ(one_value->kind(), ValueKind::kString); EXPECT_EQ(one_value->type(), type_factory.GetStringType()); EXPECT_EQ(one_value->ToString(), "1"); @@ -1108,7 +1108,7 @@ TEST_P(ValueTest, StringFromCord) { EXPECT_FALSE(zero_value->Is()); EXPECT_EQ(zero_value, zero_value); EXPECT_EQ(zero_value, Must(value_factory.CreateStringValue(absl::Cord("0")))); - EXPECT_EQ(zero_value->kind(), Kind::kString); + EXPECT_EQ(zero_value->kind(), ValueKind::kString); EXPECT_EQ(zero_value->type(), type_factory.GetStringType()); EXPECT_EQ(zero_value->ToCord(), "0"); @@ -1117,7 +1117,7 @@ TEST_P(ValueTest, StringFromCord) { EXPECT_FALSE(one_value->Is()); EXPECT_EQ(one_value, one_value); EXPECT_EQ(one_value, Must(value_factory.CreateStringValue(absl::Cord("1")))); - EXPECT_EQ(one_value->kind(), Kind::kString); + EXPECT_EQ(one_value->kind(), ValueKind::kString); EXPECT_EQ(one_value->type(), type_factory.GetStringType()); EXPECT_EQ(one_value->ToCord(), "1"); @@ -1134,7 +1134,7 @@ TEST_P(ValueTest, StringFromLiteral) { EXPECT_FALSE(zero_value->Is()); EXPECT_EQ(zero_value, zero_value); EXPECT_EQ(zero_value, Must(value_factory.CreateStringValue("0"))); - EXPECT_EQ(zero_value->kind(), Kind::kString); + EXPECT_EQ(zero_value->kind(), ValueKind::kString); EXPECT_EQ(zero_value->type(), type_factory.GetStringType()); EXPECT_EQ(zero_value->ToString(), "0"); @@ -1143,7 +1143,7 @@ TEST_P(ValueTest, StringFromLiteral) { EXPECT_FALSE(one_value->Is()); EXPECT_EQ(one_value, one_value); EXPECT_EQ(one_value, Must(value_factory.CreateStringValue("1"))); - EXPECT_EQ(one_value->kind(), Kind::kString); + EXPECT_EQ(one_value->kind(), ValueKind::kString); EXPECT_EQ(one_value->type(), type_factory.GetStringType()); EXPECT_EQ(one_value->ToString(), "1"); @@ -1160,7 +1160,7 @@ TEST_P(ValueTest, StringFromExternal) { EXPECT_FALSE(zero_value->Is()); EXPECT_EQ(zero_value, zero_value); EXPECT_EQ(zero_value, Must(value_factory.CreateStringValue("0", []() {}))); - EXPECT_EQ(zero_value->kind(), Kind::kString); + EXPECT_EQ(zero_value->kind(), ValueKind::kString); EXPECT_EQ(zero_value->type(), type_factory.GetStringType()); EXPECT_EQ(zero_value->ToString(), "0"); @@ -1169,7 +1169,7 @@ TEST_P(ValueTest, StringFromExternal) { EXPECT_FALSE(one_value->Is()); EXPECT_EQ(one_value, one_value); EXPECT_EQ(one_value, Must(value_factory.CreateStringValue("1", []() {}))); - EXPECT_EQ(one_value->kind(), Kind::kString); + EXPECT_EQ(one_value->kind(), ValueKind::kString); EXPECT_EQ(one_value->type(), type_factory.GetStringType()); EXPECT_EQ(one_value->ToString(), "1"); @@ -1187,7 +1187,7 @@ TEST_P(ValueTest, Type) { EXPECT_EQ(null_value, null_value); EXPECT_EQ(null_value, value_factory.CreateTypeValue(type_factory.GetNullType())); - EXPECT_EQ(null_value->kind(), Kind::kType); + EXPECT_EQ(null_value->kind(), ValueKind::kType); EXPECT_EQ(null_value->type(), type_factory.GetTypeType()); EXPECT_EQ(null_value->name(), "null_type"); @@ -1197,7 +1197,7 @@ TEST_P(ValueTest, Type) { EXPECT_EQ(int_value, int_value); EXPECT_EQ(int_value, value_factory.CreateTypeValue(type_factory.GetIntType())); - EXPECT_EQ(int_value->kind(), Kind::kType); + EXPECT_EQ(int_value->kind(), ValueKind::kType); EXPECT_EQ(int_value->type(), type_factory.GetTypeType()); EXPECT_EQ(int_value->name(), "int"); @@ -1214,7 +1214,7 @@ TEST_P(ValueTest, Unknown) { EXPECT_FALSE(zero_value->Is()); EXPECT_EQ(zero_value, zero_value); EXPECT_EQ(zero_value, value_factory.CreateUnknownValue()); - EXPECT_EQ(zero_value->kind(), Kind::kUnknown); + EXPECT_EQ(zero_value->kind(), ValueKind::kUnknown); EXPECT_EQ(zero_value->type(), type_factory.GetUnknownType()); } @@ -1229,7 +1229,7 @@ TEST_P(ValueTest, Optional) { EXPECT_TRUE(none_optional->Is()); EXPECT_FALSE(none_optional->Is()); EXPECT_EQ(none_optional, none_optional); - EXPECT_EQ(none_optional->kind(), Kind::kOpaque); + EXPECT_EQ(none_optional->kind(), ValueKind::kOpaque); ASSERT_OK_AND_ASSIGN(auto optional_type, type_factory.CreateOptionalType( type_factory.GetStringType())); EXPECT_EQ(none_optional->type(), optional_type); @@ -1243,7 +1243,7 @@ TEST_P(ValueTest, Optional) { EXPECT_TRUE(full_optional->Is()); EXPECT_FALSE(full_optional->Is()); EXPECT_EQ(full_optional, full_optional); - EXPECT_EQ(full_optional->kind(), Kind::kOpaque); + EXPECT_EQ(full_optional->kind(), ValueKind::kOpaque); EXPECT_EQ(full_optional->type(), optional_type); EXPECT_TRUE(full_optional->has_value()); EXPECT_EQ(full_optional->value(), value_factory.GetStringValue()); @@ -2015,7 +2015,7 @@ TEST_P(ValueTest, Enum) { EXPECT_EQ(one_value, one_value); EXPECT_EQ(one_value, Must(value_factory.CreateEnumValue(enum_type, "VALUE1"))); - EXPECT_EQ(one_value->kind(), Kind::kEnum); + EXPECT_EQ(one_value->kind(), ValueKind::kEnum); EXPECT_EQ(one_value->type(), enum_type); EXPECT_EQ(one_value->name(), "VALUE1"); EXPECT_EQ(one_value->number(), 1); @@ -2025,7 +2025,7 @@ TEST_P(ValueTest, Enum) { EXPECT_TRUE(two_value->Is()); EXPECT_FALSE(two_value->Is()); EXPECT_EQ(two_value, two_value); - EXPECT_EQ(two_value->kind(), Kind::kEnum); + EXPECT_EQ(two_value->kind(), ValueKind::kEnum); EXPECT_EQ(two_value->type(), enum_type); EXPECT_EQ(two_value->name(), "VALUE2"); EXPECT_EQ(two_value->number(), 2); @@ -2085,7 +2085,7 @@ TEST_P(ValueTest, Struct) { EXPECT_TRUE(zero_value->Is()); EXPECT_FALSE(zero_value->Is()); EXPECT_EQ(zero_value, zero_value); - EXPECT_EQ(zero_value->kind(), Kind::kStruct); + EXPECT_EQ(zero_value->kind(), ValueKind::kStruct); EXPECT_EQ(zero_value->type(), struct_type); EXPECT_EQ(zero_value.As()->value(), TestStruct{}); @@ -2096,7 +2096,7 @@ TEST_P(ValueTest, Struct) { EXPECT_TRUE(one_value->Is()); EXPECT_FALSE(one_value->Is()); EXPECT_EQ(one_value, one_value); - EXPECT_EQ(one_value->kind(), Kind::kStruct); + EXPECT_EQ(one_value->kind(), ValueKind::kStruct); EXPECT_EQ(one_value->type(), struct_type); EXPECT_EQ(one_value.As()->value(), (TestStruct{true, 1, 1, 1.0})); @@ -2185,7 +2185,7 @@ TEST_P(ValueTest, List) { EXPECT_TRUE(zero_value->Is()); EXPECT_FALSE(zero_value->Is()); EXPECT_EQ(zero_value, zero_value); - EXPECT_EQ(zero_value->kind(), Kind::kList); + EXPECT_EQ(zero_value->kind(), ValueKind::kList); EXPECT_EQ(zero_value->type(), list_type); EXPECT_EQ(zero_value.As()->value(), std::vector{}); @@ -2196,7 +2196,7 @@ TEST_P(ValueTest, List) { EXPECT_TRUE(one_value->Is()); EXPECT_FALSE(one_value->Is()); EXPECT_EQ(one_value, one_value); - EXPECT_EQ(one_value->kind(), Kind::kList); + EXPECT_EQ(one_value->kind(), ValueKind::kList); EXPECT_EQ(one_value->type(), list_type); EXPECT_EQ(one_value.As()->value(), std::vector{1}); @@ -2311,7 +2311,7 @@ TEST_P(ValueTest, Map) { EXPECT_TRUE(zero_value->Is()); EXPECT_FALSE(zero_value->Is()); EXPECT_EQ(zero_value, zero_value); - EXPECT_EQ(zero_value->kind(), Kind::kMap); + EXPECT_EQ(zero_value->kind(), ValueKind::kMap); EXPECT_EQ(zero_value->type(), map_type); EXPECT_EQ(zero_value.As()->value(), (std::map{})); @@ -2324,7 +2324,7 @@ TEST_P(ValueTest, Map) { EXPECT_TRUE(one_value->Is()); EXPECT_FALSE(one_value->Is()); EXPECT_EQ(one_value, one_value); - EXPECT_EQ(one_value->kind(), Kind::kMap); + EXPECT_EQ(one_value->kind(), ValueKind::kMap); EXPECT_EQ(one_value->type(), map_type); EXPECT_EQ(one_value.As()->value(), (std::map{{"foo", 1}})); diff --git a/base/values/bytes_value.h b/base/values/bytes_value.h index 4ec7e041f..a896abb98 100644 --- a/base/values/bytes_value.h +++ b/base/values/bytes_value.h @@ -38,7 +38,7 @@ class ValueFactory; class BytesValue : public Value { public: - static constexpr Kind kKind = Kind::kBytes; + static constexpr ValueKind kKind = ValueKind::kBytes; static Handle Empty(ValueFactory& value_factory); @@ -65,7 +65,7 @@ class BytesValue : public Value { ABSL_ATTRIBUTE_PURE_FUNCTION static std::string DebugString( const absl::Cord& value); - constexpr Kind kind() const { return kKind; } + constexpr ValueKind kind() const { return kKind; } Handle type() const { return BytesType::Get(); } @@ -139,7 +139,7 @@ class InlinedCordBytesValue final : public BytesValue, public InlineData { }; // Implementation of BytesValue that is stored inlined within a handle. This -// class is inheritently unsafe and care should be taken when using it. +// class is inherently unsafe and care should be taken when using it. class InlinedStringViewBytesValue final : public BytesValue, public InlineData { private: friend class BytesValue; diff --git a/base/values/enum_value.h b/base/values/enum_value.h index 89e8f004d..a5a51c49e 100644 --- a/base/values/enum_value.h +++ b/base/values/enum_value.h @@ -36,7 +36,7 @@ class ValueFactory; // EnumValue represents a single constant belonging to cel::EnumType. class EnumValue final : public Value, public base_internal::InlineData { public: - static constexpr Kind kKind = Kind::kEnum; + static constexpr ValueKind kKind = ValueKind::kEnum; static bool Is(const Value& value) { return value.kind() == kKind; } @@ -55,7 +55,7 @@ class EnumValue final : public Value, public base_internal::InlineData { using ConstantId = EnumType::ConstantId; - constexpr Kind kind() const { return kKind; } + constexpr ValueKind kind() const { return kKind; } const Handle& type() const { return type_; } diff --git a/base/values/error_value.h b/base/values/error_value.h index 0924a0a14..2b1e3ad47 100644 --- a/base/values/error_value.h +++ b/base/values/error_value.h @@ -35,7 +35,7 @@ class ErrorValue final : public Value, public base_internal::InlineData { ABSL_ATTRIBUTE_PURE_FUNCTION static std::string DebugString( const absl::Status& value); - static constexpr Kind kKind = Kind::kError; + static constexpr ValueKind kKind = ValueKind::kError; static bool Is(const Value& value) { return value.kind() == kKind; } @@ -47,7 +47,7 @@ class ErrorValue final : public Value, public base_internal::InlineData { return static_cast(value); } - constexpr Kind kind() const { return kKind; } + constexpr ValueKind kind() const { return kKind; } Handle type() const { return ErrorType::Get(); } diff --git a/base/values/list_value.h b/base/values/list_value.h index 60097cee3..3f706fcff 100644 --- a/base/values/list_value.h +++ b/base/values/list_value.h @@ -49,7 +49,7 @@ class ListValue : public Value { template using Builder = ListValueBuilder; - static constexpr Kind kKind = Kind::kList; + static constexpr ValueKind kKind = ValueKind::kList; static bool Is(const Value& value) { return value.kind() == kKind; } @@ -61,11 +61,11 @@ class ListValue : public Value { return static_cast(value); } - // TODO(uncreated-issue/10): implement iterators so we can have cheap concated lists + // TODO(uncreated-issue/10): implement iterators so we can have cheap concat lists Handle type() const; - constexpr Kind kind() const { return kKind; } + constexpr ValueKind kind() const { return kKind; } std::string DebugString() const; diff --git a/base/values/map_value.h b/base/values/map_value.h index 716b2de8e..716251cc9 100644 --- a/base/values/map_value.h +++ b/base/values/map_value.h @@ -51,7 +51,7 @@ class MapValue : public Value { template using Builder = MapValueBuilder; - static constexpr Kind kKind = Kind::kMap; + static constexpr ValueKind kKind = ValueKind::kMap; static bool Is(const Value& value) { return value.kind() == kKind; } @@ -63,7 +63,7 @@ class MapValue : public Value { return static_cast(value); } - constexpr Kind kind() const { return kKind; } + constexpr ValueKind kind() const { return kKind; } Handle type() const; diff --git a/base/values/map_value_builder.h b/base/values/map_value_builder.h index a9ee14698..a9ae154f2 100644 --- a/base/values/map_value_builder.h +++ b/base/values/map_value_builder.h @@ -113,13 +113,13 @@ template <> struct MapKeyHasher> { inline size_t operator()(const Handle& key) const { switch (key->kind()) { - case Kind::kBool: + case ValueKind::kBool: return absl::Hash{}(*key.As()); - case Kind::kInt: + case ValueKind::kInt: return absl::Hash{}(*key.As()); - case Kind::kUint: + case ValueKind::kUint: return absl::Hash{}(*key.As()); - case Kind::kString: + case ValueKind::kString: return absl::Hash{}(*key.As()); default: ABSL_UNREACHABLE(); @@ -147,13 +147,13 @@ struct MapKeyEqualer> { const Handle& rhs) const { ABSL_ASSERT(lhs->kind() == rhs->kind()); switch (lhs->kind()) { - case Kind::kBool: + case ValueKind::kBool: return *lhs.As() == *rhs.As(); - case Kind::kInt: + case ValueKind::kInt: return *lhs.As() == *rhs.As(); - case Kind::kUint: + case ValueKind::kUint: return *lhs.As() == *rhs.As(); - case Kind::kString: + case ValueKind::kString: return *lhs.As() == *rhs.As(); default: ABSL_UNREACHABLE(); diff --git a/base/values/opaque_value.h b/base/values/opaque_value.h index 61b589264..bfeb0d972 100644 --- a/base/values/opaque_value.h +++ b/base/values/opaque_value.h @@ -28,9 +28,9 @@ namespace cel { class OpaqueValue : public Value, public base_internal::HeapData { public: - static constexpr Kind kKind = Kind::kOpaque; + static constexpr ValueKind kKind = ValueKind::kOpaque; - static bool Is(const Value& value) { return value.kind() == Kind::kOpaque; } + static bool Is(const Value& value) { return value.kind() == kKind; } using Value::Is; @@ -40,6 +40,8 @@ class OpaqueValue : public Value, public base_internal::HeapData { return static_cast(value); } + constexpr ValueKind kind() const { return kKind; } + const Handle& type() const { return type_; } virtual std::string DebugString() const = 0; diff --git a/base/values/string_value.h b/base/values/string_value.h index 69dd7b462..f87e1b593 100644 --- a/base/values/string_value.h +++ b/base/values/string_value.h @@ -39,7 +39,7 @@ class ValueFactory; class StringValue : public Value { public: - static constexpr Kind kKind = Kind::kString; + static constexpr ValueKind kKind = ValueKind::kString; static Handle Empty(ValueFactory& value_factory); @@ -66,7 +66,7 @@ class StringValue : public Value { ABSL_ATTRIBUTE_PURE_FUNCTION static std::string DebugString( const absl::Cord& value); - constexpr Kind kind() const { return kKind; } + constexpr ValueKind kind() const { return kKind; } Handle type() const { return StringType::Get(); } diff --git a/base/values/struct_value.h b/base/values/struct_value.h index 9590ff9e0..671e9f107 100644 --- a/base/values/struct_value.h +++ b/base/values/struct_value.h @@ -49,7 +49,7 @@ class StructValueBuilderInterface; // StructValue represents an instance of cel::StructType. class StructValue : public Value { public: - static constexpr Kind kKind = Kind::kStruct; + static constexpr ValueKind kKind = ValueKind::kStruct; static bool Is(const Value& value) { return value.kind() == kKind; } @@ -63,7 +63,7 @@ class StructValue : public Value { using FieldId = StructType::FieldId; - constexpr Kind kind() const { return kKind; } + constexpr ValueKind kind() const { return kKind; } Handle type() const; diff --git a/base/values/type_value.h b/base/values/type_value.h index bb61f406e..afe59504d 100644 --- a/base/values/type_value.h +++ b/base/values/type_value.h @@ -31,7 +31,7 @@ namespace cel { class TypeValue : public Value { public: - static constexpr Kind kKind = Kind::kType; + static constexpr ValueKind kKind = ValueKind::kType; static bool Is(const Value& value) { return value.kind() == kKind; } @@ -43,7 +43,7 @@ class TypeValue : public Value { return static_cast(value); } - constexpr Kind kind() const { return kKind; } + constexpr ValueKind kind() const { return kKind; } Handle type() const { return TypeType::Get(); } diff --git a/base/values/unknown_value.h b/base/values/unknown_value.h index eb34f30e3..0fab877c1 100644 --- a/base/values/unknown_value.h +++ b/base/values/unknown_value.h @@ -29,7 +29,7 @@ namespace cel { class UnknownValue final : public Value, public base_internal::InlineData { public: - static constexpr Kind kKind = Kind::kUnknown; + static constexpr ValueKind kKind = ValueKind::kUnknown; static bool Is(const Value& value) { return value.kind() == kKind; } @@ -41,7 +41,7 @@ class UnknownValue final : public Value, public base_internal::InlineData { return static_cast(value); } - constexpr Kind kind() const { return kKind; } + constexpr ValueKind kind() const { return kKind; } const Handle& type() const { return UnknownType::Get(); } diff --git a/eval/compiler/BUILD b/eval/compiler/BUILD index b680b68c7..ceb6093b6 100644 --- a/eval/compiler/BUILD +++ b/eval/compiler/BUILD @@ -189,10 +189,10 @@ cc_library( ":flat_expr_builder_extensions", ":resolver", "//base:ast_internal", + "//base:data", "//base:function", "//base:handle", "//base:kind", - "//base:value", "//base/internal:ast_impl", "//eval/eval:const_value_step", "//eval/eval:evaluator_core", diff --git a/eval/compiler/constant_folding.cc b/eval/compiler/constant_folding.cc index 5b629254f..d32de41ec 100644 --- a/eval/compiler/constant_folding.cc +++ b/eval/compiler/constant_folding.cc @@ -226,11 +226,11 @@ class ConstantFoldingTransform { arg_kinds.reserve(arg_size); if (receiver_style) { arg_values.push_back(transform_.RemoveConstant(call_expr.target())); - arg_kinds.push_back(arg_values.back()->kind()); + arg_kinds.push_back(ValueKindToKind(arg_values.back()->kind())); } for (int i = 0; i < arg_num; i++) { arg_values.push_back(transform_.RemoveConstant(call_expr.args()[i])); - arg_kinds.push_back(arg_values.back()->kind()); + arg_kinds.push_back(ValueKindToKind(arg_values.back()->kind())); } // compute function overload diff --git a/eval/eval/BUILD b/eval/eval/BUILD index 43df77339..d0b470049 100644 --- a/eval/eval/BUILD +++ b/eval/eval/BUILD @@ -121,9 +121,9 @@ cc_library( ":evaluator_core", ":expression_step_base", "//base:attributes", + "//base:data", "//base:kind", "//base:memory", - "//base:value", "//eval/internal:errors", "//eval/internal:interop", "//eval/public:cel_number", @@ -190,11 +190,11 @@ cc_library( ":attribute_trail", ":evaluator_core", ":expression_step_base", + "//base:data", "//base:function", "//base:function_descriptor", "//base:handle", "//base:kind", - "//base:value", "//eval/internal:errors", "//eval/internal:interop", "//eval/public:cel_function", @@ -228,10 +228,9 @@ cc_library( ":evaluator_core", ":expression_step_base", "//base:ast_internal", + "//base:data", "//base:handle", "//base:memory", - "//base:type", - "//base:value", "//eval/internal:errors", "//eval/internal:interop", "//eval/public:cel_options", diff --git a/eval/eval/comprehension_step.cc b/eval/eval/comprehension_step.cc index 9410cbc42..302017588 100644 --- a/eval/eval/comprehension_step.cc +++ b/eval/eval/comprehension_step.cc @@ -112,9 +112,9 @@ absl::Status ComprehensionNextStep::Evaluate(ExecutionFrame* frame) const { // Get the current index off the stack. const auto& current_index_value = state[POS_CURRENT_INDEX]; if (!current_index_value->Is()) { - return absl::InternalError( - absl::StrCat("ComprehensionNextStep: want int64_t, got ", - CelValue::TypeName(current_index_value->kind()))); + return absl::InternalError(absl::StrCat( + "ComprehensionNextStep: want int64_t, got ", + CelValue::TypeName(ValueKindToKind(current_index_value->kind())))); } CEL_RETURN_IF_ERROR(frame->IncrementIterations()); diff --git a/eval/eval/container_access_step.cc b/eval/eval/container_access_step.cc index e1b3edd66..d8e174e6f 100644 --- a/eval/eval/container_access_step.cc +++ b/eval/eval/container_access_step.cc @@ -37,13 +37,13 @@ using ::cel::BoolValue; using ::cel::DoubleValue; using ::cel::Handle; using ::cel::IntValue; -using ::cel::Kind; -using ::cel::KindToString; using ::cel::ListValue; using ::cel::MapValue; using ::cel::StringValue; using ::cel::UintValue; using ::cel::Value; +using ::cel::ValueKind; +using ::cel::ValueKindToString; using ::cel::extensions::ProtoMemoryManager; using ::cel::interop_internal::CreateErrorValueFromView; using ::cel::interop_internal::CreateIntValue; @@ -79,11 +79,11 @@ class ContainerAccessStep : public ExpressionStepBase { absl::optional CelNumberFromValue(const Handle& value) { switch (value->kind()) { - case Kind::kInt64: + case ValueKind::kInt64: return CelNumber::FromInt64(value.As()->value()); - case Kind::kUint64: + case ValueKind::kUint64: return CelNumber::FromUint64(value.As()->value()); - case Kind::kDouble: + case ValueKind::kDouble: return CelNumber::FromDouble(value.As()->value()); default: return absl::nullopt; @@ -91,28 +91,28 @@ absl::optional CelNumberFromValue(const Handle& value) { } absl::Status CheckMapKeyType(const Handle& key) { - Kind kind = key->kind(); + ValueKind kind = key->kind(); switch (kind) { - case Kind::kString: - case Kind::kInt64: - case Kind::kUint64: - case Kind::kBool: + case ValueKind::kString: + case ValueKind::kInt64: + case ValueKind::kUint64: + case ValueKind::kBool: return absl::OkStatus(); default: - return absl::InvalidArgumentError( - absl::StrCat("Invalid map key type: '", KindToString(kind), "'")); + return absl::InvalidArgumentError(absl::StrCat( + "Invalid map key type: '", ValueKindToString(kind), "'")); } } AttributeQualifier AttributeQualifierFromValue(const Handle& v) { switch (v->kind()) { - case Kind::kString: + case ValueKind::kString: return AttributeQualifier::OfString(v.As()->ToString()); - case Kind::kInt64: + case ValueKind::kInt64: return AttributeQualifier::OfInt(v.As()->value()); - case Kind::kUint64: + case ValueKind::kUint64: return AttributeQualifier::OfUint(v.As()->value()); - case Kind::kBool: + case ValueKind::kBool: return AttributeQualifier::OfBool(v.As()->value()); default: // Non-matching qualifier. @@ -199,7 +199,7 @@ absl::StatusOr> ContainerAccessStep::LookupInList( return absl::UnknownError( absl::StrCat("Index error: expected integer type, got ", - CelValue::TypeName(key->kind()))); + CelValue::TypeName(ValueKindToKind(key->kind())))); } ContainerAccessStep::LookupResult ContainerAccessStep::PerformLookup( @@ -244,7 +244,7 @@ ContainerAccessStep::LookupResult ContainerAccessStep::PerformLookup( // Select steps can be applied to either maps or messages switch (container->kind()) { - case Kind::kMap: { + case ValueKind::kMap: { auto result = LookupInMap(container.As(), key, frame); if (!result.ok()) { return {CreateErrorValueFromView(Arena::Create( @@ -253,7 +253,7 @@ ContainerAccessStep::LookupResult ContainerAccessStep::PerformLookup( } return {std::move(result).value(), std::move(trail)}; } - case CelValue::Type::kList: { + case ValueKind::kList: { auto result = LookupInList(container.As(), key, frame); if (!result.ok()) { return {CreateErrorValueFromView(Arena::Create( @@ -266,7 +266,7 @@ ContainerAccessStep::LookupResult ContainerAccessStep::PerformLookup( return {CreateErrorValueFromView(Arena::Create( arena, absl::StatusCode::kInvalidArgument, absl::StrCat("Invalid container type: '", - KindToString(container->kind()), "'"))), + ValueKindToString(container->kind()), "'"))), std::move(trail)}; } } diff --git a/eval/eval/function_step.cc b/eval/eval/function_step.cc index 9af0d05b1..64feec846 100644 --- a/eval/eval/function_step.cc +++ b/eval/eval/function_step.cc @@ -43,6 +43,7 @@ namespace { using ::cel::FunctionEvaluationContext; using ::cel::Handle; using ::cel::Value; +using ::cel::ValueKindToKind; // Determine if the overload should be considered. Overloads that can consume // errors or unknown sets must be allowed as a non-strict function. @@ -215,7 +216,8 @@ absl::StatusOr> AbstractFunctionStep::DoEvaluate( if (!arg_types.empty()) { absl::StrAppend(&arg_types, ", "); } - absl::StrAppend(&arg_types, CelValue::TypeName(arg->kind())); + absl::StrAppend(&arg_types, + CelValue::TypeName(ValueKindToKind(arg->kind()))); } // If no errors or unknowns in input args, create new CelError for missing @@ -301,11 +303,12 @@ absl::StatusOr LazyFunctionStep::ResolveFunction( const ExecutionFrame* frame) const { ResolveResult result = absl::nullopt; - std::vector arg_types(num_arguments_); + std::vector arg_types(num_arguments_); - std::transform( - input_args.begin(), input_args.end(), arg_types.begin(), - [](const cel::Handle& value) { return value->kind(); }); + std::transform(input_args.begin(), input_args.end(), arg_types.begin(), + [](const cel::Handle& value) { + return ValueKindToKind(value->kind()); + }); CelFunctionDescriptor matcher{name_, receiver_style_, arg_types}; diff --git a/eval/eval/select_step.cc b/eval/eval/select_step.cc index 07c0b93e5..ca2eb545e 100644 --- a/eval/eval/select_step.cc +++ b/eval/eval/select_step.cc @@ -33,12 +33,12 @@ namespace { using ::cel::ErrorValue; using ::cel::Handle; -using ::cel::Kind; using ::cel::MapValue; using ::cel::NullValue; using ::cel::StructValue; using ::cel::UnknownValue; using ::cel::Value; +using ::cel::ValueKind; using ::cel::extensions::ProtoMemoryManager; using ::cel::interop_internal::CreateBoolValue; using ::cel::interop_internal::CreateError; @@ -204,11 +204,11 @@ absl::Status SelectStep::Evaluate(ExecutionFrame* frame) const { // Handle test only Select. if (test_field_presence_) { switch (arg->kind()) { - case Kind::kMap: + case ValueKind::kMap: frame->value_stack().PopAndPush(TestOnlySelect( arg.As(), field_, frame->memory_manager())); return absl::OkStatus(); - case Kind::kMessage: + case ValueKind::kMessage: frame->value_stack().PopAndPush( TestOnlySelect(arg.As(), field_, frame->memory_manager(), frame->type_manager())); @@ -221,7 +221,7 @@ absl::Status SelectStep::Evaluate(ExecutionFrame* frame) const { // Normal select path. // Select steps can be applied to either maps or messages switch (arg->kind()) { - case Kind::kStruct: { + case ValueKind::kStruct: { CEL_ASSIGN_OR_RETURN(Handle result, CreateValueFromField(arg.As(), frame)); frame->value_stack().PopAndPush(std::move(result), @@ -229,7 +229,7 @@ absl::Status SelectStep::Evaluate(ExecutionFrame* frame) const { return absl::OkStatus(); } - case CelValue::Type::kMap: { + case ValueKind::kMap: { const auto& cel_map = arg.As(); auto cel_field = CreateStringValueFromView(field_); CEL_ASSIGN_OR_RETURN( diff --git a/eval/internal/interop.cc b/eval/internal/interop.cc index a2163f778..8e18e6bab 100644 --- a/eval/internal/interop.cc +++ b/eval/internal/interop.cc @@ -425,9 +425,9 @@ absl::StatusOr ToLegacyValue(google::protobuf::Arena* arena, const Handle& value, bool unchecked) { switch (value->kind()) { - case Kind::kNullType: + case ValueKind::kNullType: return CelValue::CreateNull(); - case Kind::kError: { + case ValueKind::kError: { if (base_internal::Metadata::IsTrivial(*value)) { return CelValue::CreateError( ErrorValueAccess::value_ptr(*value.As())); @@ -435,11 +435,7 @@ absl::StatusOr ToLegacyValue(google::protobuf::Arena* arena, return CelValue::CreateError(google::protobuf::Arena::Create( arena, value.As()->value())); } - case Kind::kDyn: - break; - case Kind::kAny: - break; - case Kind::kType: { + case ValueKind::kType: { // Should be fine, so long as we are using an arena allocator. // We can only transport legacy type values. if (base_internal::Metadata::GetInlineVariant< @@ -452,30 +448,30 @@ absl::StatusOr ToLegacyValue(google::protobuf::Arena* arena, return CelValue::CreateCelTypeView(*type_name); } - case Kind::kBool: + case ValueKind::kBool: return CelValue::CreateBool(value.As()->value()); - case Kind::kInt: + case ValueKind::kInt: return CelValue::CreateInt64(value.As()->value()); - case Kind::kUint: + case ValueKind::kUint: return CelValue::CreateUint64(value.As()->value()); - case Kind::kDouble: + case ValueKind::kDouble: return CelValue::CreateDouble(value.As()->value()); - case Kind::kString: + case ValueKind::kString: return absl::visit(StringValueToLegacyVisitor{arena}, GetStringValueRep(value.As())); - case Kind::kBytes: + case ValueKind::kBytes: return absl::visit(BytesValueToLegacyVisitor{arena}, GetBytesValueRep(value.As())); - case Kind::kEnum: + case ValueKind::kEnum: break; - case Kind::kDuration: + case ValueKind::kDuration: return unchecked ? CelValue::CreateUncheckedDuration( value.As()->value()) : CelValue::CreateDuration(value.As()->value()); - case Kind::kTimestamp: + case ValueKind::kTimestamp: return CelValue::CreateTimestamp(value.As()->value()); - case Kind::kList: { + case ValueKind::kList: { if (value->Is()) { // Fast path. return CelValue::CreateList(reinterpret_cast( @@ -484,7 +480,7 @@ absl::StatusOr ToLegacyValue(google::protobuf::Arena* arena, return CelValue::CreateList( google::protobuf::Arena::Create(arena, value.As())); } - case Kind::kMap: { + case ValueKind::kMap: { if (value->Is()) { // Fast path. return CelValue::CreateMap(reinterpret_cast( @@ -493,7 +489,7 @@ absl::StatusOr ToLegacyValue(google::protobuf::Arena* arena, return CelValue::CreateMap( google::protobuf::Arena::Create(arena, value.As())); } - case Kind::kStruct: { + case ValueKind::kStruct: { if (value->Is()) { // "Legacy". uintptr_t message = LegacyStructValueAccess::Message( @@ -513,7 +509,7 @@ absl::StatusOr ToLegacyValue(google::protobuf::Arena* arena, return absl::UnimplementedError( "only legacy struct types and values can be used for interop"); } - case Kind::kUnknown: { + case ValueKind::kUnknown: { if (base_internal::Metadata::IsTrivial(*value)) { return CelValue::CreateUnknownSet( UnknownValueAccess::value_ptr(*value.As())); @@ -525,9 +521,9 @@ absl::StatusOr ToLegacyValue(google::protobuf::Arena* arena, default: break; } - return absl::UnimplementedError( - absl::StrCat("conversion from cel::Value to CelValue for type ", - KindToString(value->kind()), " is not yet implemented")); + return absl::UnimplementedError(absl::StrCat( + "conversion from cel::Value to CelValue for type ", + ValueKindToString(value->kind()), " is not yet implemented")); } Handle CreateNullValue() { diff --git a/extensions/protobuf/struct_value.cc b/extensions/protobuf/struct_value.cc index 6ac6eac25..9e445c692 100644 --- a/extensions/protobuf/struct_value.cc +++ b/extensions/protobuf/struct_value.cc @@ -1373,10 +1373,10 @@ void ProtoDebugStringMap(std::string& out, const google::protobuf::Message& mess bool ToProtoMapKey(google::protobuf::MapKey& key, const Handle& value, const google::protobuf::FieldDescriptor& field) { switch (value->kind()) { - case Kind::kBool: + case ValueKind::kBool: key.SetBoolValue(value.As()->value()); break; - case Kind::kInt: { + case ValueKind::kInt: { int64_t cpp_key = value.As()->value(); const auto* key_desc = field.message_type()->map_key(); switch (key_desc->cpp_type()) { @@ -1394,7 +1394,7 @@ bool ToProtoMapKey(google::protobuf::MapKey& key, const Handle& value, ABSL_UNREACHABLE(); } } break; - case Kind::kUint: { + case ValueKind::kUint: { uint64_t cpp_key = value.As()->value(); const auto* key_desc = field.message_type()->map_key(); switch (key_desc->cpp_type()) { @@ -1411,7 +1411,7 @@ bool ToProtoMapKey(google::protobuf::MapKey& key, const Handle& value, ABSL_UNREACHABLE(); } } break; - case Kind::kString: + case ValueKind::kString: key.SetStringValue(value.As()->ToString()); break; default: From ca2fdb3d05a1c630feb925934f4783bc20a4f939 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Tue, 23 May 2023 14:07:04 -0700 Subject: [PATCH 303/303] Update the default type provider (protobuf reflection impl) to return a specific type info instance instead of the generic one. PiperOrigin-RevId: 534552333 --- eval/public/structs/BUILD | 1 + eval/public/structs/legacy_type_info_apis.h | 3 + .../structs/proto_message_type_adapter.cc | 28 ++++++++ .../structs/proto_message_type_adapter.h | 23 ++++++- .../proto_message_type_adapter_test.cc | 66 +++++++++++++++++++ .../protobuf_descriptor_type_provider.cc | 32 ++++----- .../protobuf_descriptor_type_provider.h | 8 ++- 7 files changed, 142 insertions(+), 19 deletions(-) diff --git a/eval/public/structs/BUILD b/eval/public/structs/BUILD index b8edab5eb..bba85ec94 100644 --- a/eval/public/structs/BUILD +++ b/eval/public/structs/BUILD @@ -259,6 +259,7 @@ cc_test( ":legacy_type_adapter", ":legacy_type_info_apis", ":proto_message_type_adapter", + "//eval/public:cel_options", "//eval/public:cel_value", "//eval/public:message_wrapper", "//eval/public/containers:container_backed_list_impl", diff --git a/eval/public/structs/legacy_type_info_apis.h b/eval/public/structs/legacy_type_info_apis.h index 49aff7e24..d9d145ffb 100644 --- a/eval/public/structs/legacy_type_info_apis.h +++ b/eval/public/structs/legacy_type_info_apis.h @@ -31,6 +31,9 @@ class LegacyTypeMutationApis; // Provides ability to obtain field access apis, type info, and debug // representation of a message. // +// The message parameter may wrap a nullptr to request generic accessors / +// mutators for the TypeInfo instance if it is available. +// // This is implemented as a separate class from LegacyTypeAccessApis to resolve // cyclic dependency between CelValue (which needs to access these apis to // provide DebugString and ObtainCelTypename) and LegacyTypeAccessApis (which diff --git a/eval/public/structs/proto_message_type_adapter.cc b/eval/public/structs/proto_message_type_adapter.cc index 42c7f10d9..74b32f6f2 100644 --- a/eval/public/structs/proto_message_type_adapter.cc +++ b/eval/public/structs/proto_message_type_adapter.cc @@ -301,6 +301,34 @@ CelValue MessageCelValueFactory(const google::protobuf::Message* message) { } // namespace +std::string ProtoMessageTypeAdapter::DebugString( + const MessageWrapper& wrapped_message) const { + if (!wrapped_message.HasFullProto() || + wrapped_message.message_ptr() == nullptr) { + return UnsupportedTypeName(); + } + auto* message = cel::internal::down_cast( + wrapped_message.message_ptr()); + return message->ShortDebugString(); +} + +const std::string& ProtoMessageTypeAdapter::GetTypename( + const MessageWrapper& wrapped_message) const { + return descriptor_->full_name(); +} + +const LegacyTypeMutationApis* ProtoMessageTypeAdapter::GetMutationApis( + const MessageWrapper& wrapped_message) const { + // Defer checks for misuse on wrong message kind in the accessor calls. + return this; +} + +const LegacyTypeAccessApis* ProtoMessageTypeAdapter::GetAccessApis( + const MessageWrapper& wrapped_message) const { + // Defer checks for misuse on wrong message kind in the builder calls. + return this; +} + absl::Status ProtoMessageTypeAdapter::ValidateSetFieldOp( bool assertion, absl::string_view field, absl::string_view detail) const { if (!assertion) { diff --git a/eval/public/structs/proto_message_type_adapter.h b/eval/public/structs/proto_message_type_adapter.h index 12ba4ae0e..43b67f285 100644 --- a/eval/public/structs/proto_message_type_adapter.h +++ b/eval/public/structs/proto_message_type_adapter.h @@ -15,6 +15,7 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_PROTO_MESSAGE_TYPE_ADAPTER_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_PROTO_MESSAGE_TYPE_ADAPTER_H_ +#include #include #include "google/protobuf/descriptor.h" @@ -29,7 +30,13 @@ namespace google::api::expr::runtime { -class ProtoMessageTypeAdapter : public LegacyTypeAccessApis, +// Implementation for legacy struct (message) type apis using reflection. +// +// Note: The type info API implementation attached to message values is +// generally the duck-typed instance to support the default behavior of +// deferring to the protobuf reflection apis on the message instance. +class ProtoMessageTypeAdapter : public LegacyTypeInfoApis, + public LegacyTypeAccessApis, public LegacyTypeMutationApis { public: ProtoMessageTypeAdapter(const google::protobuf::Descriptor* descriptor, @@ -38,6 +45,19 @@ class ProtoMessageTypeAdapter : public LegacyTypeAccessApis, ~ProtoMessageTypeAdapter() override = default; + // Implement LegacyTypeInfoApis + std::string DebugString(const MessageWrapper& wrapped_message) const override; + + const std::string& GetTypename( + const MessageWrapper& wrapped_message) const override; + + const LegacyTypeAccessApis* GetAccessApis( + const MessageWrapper& wrapped_message) const override; + + const LegacyTypeMutationApis* GetMutationApis( + const MessageWrapper& wrapped_message) const override; + + // Implement LegacyTypeMutation APIs. absl::StatusOr NewInstance( cel::MemoryManager& memory_manager) const override; @@ -52,6 +72,7 @@ class ProtoMessageTypeAdapter : public LegacyTypeAccessApis, cel::MemoryManager& memory_manager, CelValue::MessageWrapper::Builder instance) const override; + // Implement LegacyTypeAccessAPIs. absl::StatusOr GetField( absl::string_view field_name, const CelValue::MessageWrapper& instance, ProtoWrapperTypeOptions unboxing_option, diff --git a/eval/public/structs/proto_message_type_adapter_test.cc b/eval/public/structs/proto_message_type_adapter_test.cc index 0ddabcb46..ad90a279c 100644 --- a/eval/public/structs/proto_message_type_adapter_test.cc +++ b/eval/public/structs/proto_message_type_adapter_test.cc @@ -20,6 +20,7 @@ #include "google/protobuf/message.h" #include "google/protobuf/message_lite.h" #include "absl/status/status.h" +#include "eval/public/cel_options.h" #include "eval/public/cel_value.h" #include "eval/public/containers/container_backed_list_impl.h" #include "eval/public/containers/container_backed_map_impl.h" @@ -671,5 +672,70 @@ TEST(ProtoMessageTypeAdapter, AdaptFromWellKnownTypeNotAMessageError) { StatusIs(absl::StatusCode::kInternal)); } +TEST(ProtoMesssageTypeAdapter, TypeInfoDebug) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + ProtoMemoryManager manager(&arena); + + TestMessage message; + message.set_int64_value(42); + EXPECT_THAT(adapter.DebugString(MessageWrapper(&message, &adapter)), + HasSubstr(message.ShortDebugString())); + + EXPECT_THAT(adapter.DebugString(MessageWrapper()), + HasSubstr("")); +} + +TEST(ProtoMesssageTypeAdapter, TypeInfoName) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + ProtoMemoryManager manager(&arena); + + EXPECT_EQ(adapter.GetTypename(MessageWrapper()), + "google.api.expr.runtime.TestMessage"); +} + +TEST(ProtoMesssageTypeAdapter, TypeInfoMutator) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + ProtoMemoryManager manager(&arena); + + const LegacyTypeMutationApis* api = adapter.GetMutationApis(MessageWrapper()); + ASSERT_NE(api, nullptr); + + ASSERT_OK_AND_ASSIGN(MessageWrapper::Builder builder, + api->NewInstance(manager)); + EXPECT_NE(dynamic_cast(builder.message_ptr()), nullptr); +} + +TEST(ProtoMesssageTypeAdapter, TypeInfoAccesor) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + ProtoMemoryManager manager(&arena); + + TestMessage message; + message.set_int64_value(42); + CelValue::MessageWrapper wrapped(&message, &adapter); + + const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); + ASSERT_NE(api, nullptr); + + EXPECT_THAT(api->GetField("int64_value", wrapped, + ProtoWrapperTypeOptions::kUnsetNull, manager), + IsOkAndHolds(test::IsCelInt64(42))); +} + } // namespace } // namespace google::api::expr::runtime diff --git a/eval/public/structs/protobuf_descriptor_type_provider.cc b/eval/public/structs/protobuf_descriptor_type_provider.cc index 03313d733..5c18ce3be 100644 --- a/eval/public/structs/protobuf_descriptor_type_provider.cc +++ b/eval/public/structs/protobuf_descriptor_type_provider.cc @@ -25,18 +25,7 @@ namespace google::api::expr::runtime { absl::optional ProtobufDescriptorProvider::ProvideLegacyType( absl::string_view name) const { - const ProtoMessageTypeAdapter* result = nullptr; - { - absl::MutexLock lock(&mu_); - auto it = type_cache_.find(name); - if (it != type_cache_.end()) { - result = it->second.get(); - } else { - auto type_provider = GetType(name); - result = type_provider.get(); - type_cache_[name] = std::move(type_provider); - } - } + const ProtoMessageTypeAdapter* result = GetTypeAdapter(name); if (result == nullptr) { return absl::nullopt; } @@ -47,11 +36,11 @@ absl::optional ProtobufDescriptorProvider::ProvideLegacyType( absl::optional ProtobufDescriptorProvider::ProvideLegacyTypeInfo( absl::string_view name) const { - return &GetGenericProtoTypeInfoInstance(); + return GetTypeAdapter(name); } -std::unique_ptr ProtobufDescriptorProvider::GetType( - absl::string_view name) const { +std::unique_ptr +ProtobufDescriptorProvider::CreateTypeAdapter(absl::string_view name) const { const google::protobuf::Descriptor* descriptor = descriptor_pool_->FindMessageTypeByName(name); if (descriptor == nullptr) { @@ -61,4 +50,17 @@ std::unique_ptr ProtobufDescriptorProvider::GetType( return std::make_unique(descriptor, message_factory_); } + +const ProtoMessageTypeAdapter* ProtobufDescriptorProvider::GetTypeAdapter( + absl::string_view name) const { + absl::MutexLock lock(&mu_); + auto it = type_cache_.find(name); + if (it != type_cache_.end()) { + return it->second.get(); + } + auto type_provider = CreateTypeAdapter(name); + const ProtoMessageTypeAdapter* result = type_provider.get(); + type_cache_[name] = std::move(type_provider); + return result; +} } // namespace google::api::expr::runtime diff --git a/eval/public/structs/protobuf_descriptor_type_provider.h b/eval/public/structs/protobuf_descriptor_type_provider.h index e85ebb85d..b669af662 100644 --- a/eval/public/structs/protobuf_descriptor_type_provider.h +++ b/eval/public/structs/protobuf_descriptor_type_provider.h @@ -45,11 +45,13 @@ class ProtobufDescriptorProvider : public LegacyTypeProvider { absl::string_view name) const override; private: - // Run a lookup if the type adapter hasn't already been built. - // returns nullptr if not found. - std::unique_ptr GetType( + // Create a new type instance if found in the registered descriptor pool. + // Otherwise, returns nullptr. + std::unique_ptr CreateTypeAdapter( absl::string_view name) const; + const ProtoMessageTypeAdapter* GetTypeAdapter(absl::string_view name) const; + const google::protobuf::DescriptorPool* descriptor_pool_; google::protobuf::MessageFactory* message_factory_; mutable absl::flat_hash_map