From 7e284e92a0420a2e311426b82c675034476feec5 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 9 Jul 2025 14:10:23 -0700 Subject: [PATCH 1/4] Add clarifying notes to CEL codelab exercise 4 solution. PiperOrigin-RevId: 781204887 --- codelab/solutions/exercise4.cc | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/codelab/solutions/exercise4.cc b/codelab/solutions/exercise4.cc index f56789a4d..244fdac05 100644 --- a/codelab/solutions/exercise4.cc +++ b/codelab/solutions/exercise4.cc @@ -91,6 +91,9 @@ absl::StatusOr> MakeConfiguredCompiler() { // Codelab part 1: // Add a declaration for the map.contains(string, V) function. auto& checker_builder = builder->GetCheckerBuilder(); + // Note: we use MakeMemberOverloadDecl instead of MakeOverloadDecl + // because the function is receiver style, meaning that it is called as + // e1.f(e2) instead of f(e1, e2). CEL_ASSIGN_OR_RETURN( cel::FunctionDecl decl, cel::MakeFunctionDecl( @@ -100,6 +103,8 @@ absl::StatusOr> MakeConfiguredCompiler() { cel::MapType(checker_builder.arena(), cel::StringType(), cel::TypeParamType("V")), cel::StringType(), cel::TypeParamType("V")))); + // Note: we use MergeFunction instead of AddFunction because we are adding + // an overload to an already declared function with the same name. CEL_RETURN_IF_ERROR(checker_builder.MergeFunction(decl)); return builder->Build(); } @@ -135,7 +140,7 @@ class Evaluator { if (bool value; result.GetValue(&value)) { return value; - } else if (const CelError * value; result.GetValue(&value)) { + } else if (const CelError* value; result.GetValue(&value)) { return *value; } else { return absl::InvalidArgumentError( From 69fb7f760e2dc7ebe89911ffc907cf7fdb864615 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 10 Jul 2025 16:59:25 -0700 Subject: [PATCH 2/4] Load proto_library and cc_proto_library rules from protobuf and bump minimum version to 28.3 PiperOrigin-RevId: 781737998 --- MODULE.bazel | 2 +- eval/tests/BUILD | 3 +++ eval/testutil/BUILD | 3 +++ 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/MODULE.bazel b/MODULE.bazel index 565d57a91..a676906cc 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -33,7 +33,7 @@ bazel_dep( ) bazel_dep( name = "protobuf", - version = "27.0", + version = "28.3", repo_name = "com_google_protobuf", ) bazel_dep( diff --git a/eval/tests/BUILD b/eval/tests/BUILD index 0f9997bb2..c98c02206 100644 --- a/eval/tests/BUILD +++ b/eval/tests/BUILD @@ -2,6 +2,9 @@ # # +load("@com_google_protobuf//bazel:cc_proto_library.bzl", "cc_proto_library") +load("@com_google_protobuf//bazel:proto_library.bzl", "proto_library") + package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/eval/testutil/BUILD b/eval/testutil/BUILD index 5d80af860..cb35e6752 100644 --- a/eval/testutil/BUILD +++ b/eval/testutil/BUILD @@ -1,3 +1,6 @@ +load("@com_google_protobuf//bazel:cc_proto_library.bzl", "cc_proto_library") +load("@com_google_protobuf//bazel:proto_library.bzl", "proto_library") + # This package contains testing utility code package(default_visibility = ["//visibility:public"]) From 2c59cf330feaeb48f7088535b93f3f9226ded228 Mon Sep 17 00:00:00 2001 From: Muhammad Askri Date: Fri, 11 Jul 2025 17:07:25 -0700 Subject: [PATCH 3/4] Adding the CEL regex extensions PiperOrigin-RevId: 782160181 --- extensions/BUILD | 48 +++++ extensions/regex_ext.cc | 263 +++++++++++++++++++++++++++ extensions/regex_ext.h | 94 ++++++++++ extensions/regex_ext_test.cc | 332 +++++++++++++++++++++++++++++++++++ 4 files changed, 737 insertions(+) create mode 100644 extensions/regex_ext.cc create mode 100644 extensions/regex_ext.h create mode 100644 extensions/regex_ext_test.cc diff --git a/extensions/BUILD b/extensions/BUILD index 16f2c0be4..f127e1eed 100644 --- a/extensions/BUILD +++ b/extensions/BUILD @@ -661,6 +661,54 @@ cc_library( ], ) +cc_library( + name = "regex_ext", + srcs = ["regex_ext.cc"], + hdrs = ["regex_ext.h"], + deps = [ + "//common:value", + "//eval/public:cel_function_registry", + "//eval/public:cel_options", + "//internal:status_macros", + "//runtime:function_adapter", + "//runtime:function_registry", + "//runtime:runtime_options", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_protobuf//:protobuf", + "@com_googlesource_code_re2//:re2", + ], +) + +cc_test( + name = "regex_ext_test", + srcs = ["regex_ext_test.cc"], + deps = [ + ":regex_ext", + "//common:value", + "//common:value_testing", + "//extensions/protobuf:runtime_adapter", + "//internal:status_macros", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//parser", + "//runtime", + "//runtime:activation", + "//runtime:optional_types", + "//runtime:reference_resolver", + "//runtime:runtime_builder", + "//runtime:runtime_options", + "//runtime:standard_runtime_builder_factory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", + ], +) + cc_test( name = "formatting_test", srcs = ["formatting_test.cc"], diff --git a/extensions/regex_ext.cc b/extensions/regex_ext.cc new file mode 100644 index 000000000..54cb3e24d --- /dev/null +++ b/extensions/regex_ext.cc @@ -0,0 +1,263 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/regex_ext.h" + +#include +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "common/value.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_options.h" +#include "internal/status_macros.h" +#include "runtime/function_adapter.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" +#include "re2/re2.h" + +namespace cel::extensions { +namespace { + +Value Extract(const StringValue& target, const StringValue& regex, + const google::protobuf::DescriptorPool* ABSL_NONNULL descriptor_pool, + google::protobuf::MessageFactory* ABSL_NONNULL message_factory, + google::protobuf::Arena* ABSL_NONNULL arena) { + std::string target_scratch; + std::string regex_scratch; + absl::string_view target_view = target.ToStringView(&target_scratch); + absl::string_view regex_view = regex.ToStringView(®ex_scratch); + RE2 re2(regex_view); + if (!re2.ok()) { + return ErrorValue(absl::InvalidArgumentError( + absl::StrFormat("given regex is invalid: %s", re2.error()))); + } + const int group_count = re2.NumberOfCapturingGroups(); + if (group_count > 1) { + return ErrorValue(absl::InvalidArgumentError(absl::StrFormat( + "regular expression has more than one capturing group: %s", + regex_view))); + } + + // Space for the full match (\0) and the first capture group (\1). + absl::string_view submatches[2]; + if (re2.Match(target_view, 0, target_view.length(), RE2::UNANCHORED, + submatches, 2)) { + // Return the capture group if it exists else return the full match. + const absl::string_view result_view = + (group_count == 1) ? submatches[1] : submatches[0]; + return OptionalValue::Of(StringValue::From(result_view, arena), arena); + } + + return OptionalValue::None(); +} + +Value ExtractAll(const StringValue& target, const StringValue& regex, + const google::protobuf::DescriptorPool* ABSL_NONNULL descriptor_pool, + google::protobuf::MessageFactory* ABSL_NONNULL message_factory, + google::protobuf::Arena* ABSL_NONNULL arena) { + std::string target_scratch; + std::string regex_scratch; + absl::string_view target_view = target.ToStringView(&target_scratch); + absl::string_view regex_view = regex.ToStringView(®ex_scratch); + RE2 re2(regex_view); + if (!re2.ok()) { + return ErrorValue(absl::InvalidArgumentError( + absl::StrFormat("given regex is invalid: %s", re2.error()))); + } + const int group_count = re2.NumberOfCapturingGroups(); + if (group_count > 1) { + return ErrorValue(absl::InvalidArgumentError(absl::StrFormat( + "regular expression has more than one capturing group: %s", + regex_view))); + } + + auto builder = NewListValueBuilder(arena); + absl::string_view temp_target = target_view; + + // Space for the full match (\0) and the first capture group (\1). + absl::string_view submatches[2]; + const int group_to_extract = (group_count == 1) ? 1 : 0; + + while (re2.Match(temp_target, 0, temp_target.length(), RE2::UNANCHORED, + submatches, group_count + 1)) { + const absl::string_view& full_match = submatches[0]; + const absl::string_view& desired_capture = submatches[group_to_extract]; + + // Avoid infinite loops on zero-length matches + if (full_match.empty()) { + if (temp_target.empty()) { + break; + } + temp_target.remove_prefix(1); + continue; + } + + if (group_count == 1 && desired_capture.empty()) { + temp_target.remove_prefix(full_match.data() - temp_target.data() + + full_match.length()); + continue; + } + + absl::Status status = + builder->Add(StringValue::From(desired_capture, arena)); + if (!status.ok()) { + return ErrorValue(status); + } + temp_target.remove_prefix(full_match.data() - temp_target.data() + + full_match.length()); + } + + return std::move(*builder).Build(); +} + +Value ReplaceAll(const StringValue& target, const StringValue& regex, + const StringValue& replacement, + const google::protobuf::DescriptorPool* ABSL_NONNULL descriptor_pool, + google::protobuf::MessageFactory* ABSL_NONNULL message_factory, + google::protobuf::Arena* ABSL_NONNULL arena) { + std::string target_scratch; + std::string regex_scratch; + std::string replacement_scratch; + absl::string_view target_view = target.ToStringView(&target_scratch); + absl::string_view regex_view = regex.ToStringView(®ex_scratch); + absl::string_view replacement_view = + replacement.ToStringView(&replacement_scratch); + RE2 re2(regex_view); + if (!re2.ok()) { + return ErrorValue(absl::InvalidArgumentError( + absl::StrFormat("given regex is invalid: %s", re2.error()))); + } + + std::string error_string; + if (!re2.CheckRewriteString(replacement_view, &error_string)) { + return ErrorValue(absl::InvalidArgumentError( + absl::StrFormat("invalid replacement string: %s", error_string))); + } + + std::string output(target_view); + RE2::GlobalReplace(&output, re2, replacement_view); + + return StringValue::From(std::move(output), arena); +} + +Value ReplaceN(const StringValue& target, const StringValue& regex, + const StringValue& replacement, int64_t count, + const google::protobuf::DescriptorPool* ABSL_NONNULL descriptor_pool, + google::protobuf::MessageFactory* ABSL_NONNULL message_factory, + google::protobuf::Arena* ABSL_NONNULL arena) { + if (count == 0) { + return target; + } + if (count < 0) { + return ReplaceAll(target, regex, replacement, descriptor_pool, + message_factory, arena); + } + + std::string target_scratch; + std::string regex_scratch; + std::string replacement_scratch; + absl::string_view target_view = target.ToStringView(&target_scratch); + absl::string_view regex_view = regex.ToStringView(®ex_scratch); + absl::string_view replacement_view = + replacement.ToStringView(&replacement_scratch); + RE2 re2(regex_view); + if (!re2.ok()) { + return ErrorValue(absl::InvalidArgumentError( + absl::StrFormat("given regex is invalid: %s", re2.error()))); + } + std::string error_string; + if (!re2.CheckRewriteString(replacement_view, &error_string)) { + return ErrorValue(absl::InvalidArgumentError( + absl::StrFormat("invalid replacement string: %s", error_string))); + } + + std::string output; + absl::string_view temp_target = target_view; + int replaced_count = 0; + // RE2's Rewrite only supports substitutions for groups \0 through \9. + absl::string_view match[10]; + int nmatch = std::min(9, re2.NumberOfCapturingGroups()) + 1; + + while (replaced_count < count && + re2.Match(temp_target, 0, temp_target.length(), RE2::UNANCHORED, match, + nmatch)) { + absl::string_view full_match = match[0]; + + output.append(temp_target.data(), full_match.data() - temp_target.data()); + + if (!re2.Rewrite(&output, replacement_view, match, nmatch)) { + // This should ideally not happen given CheckRewriteString passed + return ErrorValue(absl::InternalError("rewrite failed unexpectedly")); + } + + temp_target.remove_prefix(full_match.data() - temp_target.data() + + full_match.length()); + replaced_count++; + } + + output.append(temp_target.data(), temp_target.length()); + + return StringValue::From(std::move(output), arena); +} + +} // namespace + +absl::Status RegisterRegexExtensionFunctions(FunctionRegistry& registry) { + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter, StringValue, StringValue>:: + RegisterGlobalOverload("regex.extract", &Extract, registry))); + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter, StringValue, StringValue>:: + RegisterGlobalOverload("regex.extractAll", &ExtractAll, registry))); + CEL_RETURN_IF_ERROR( + (TernaryFunctionAdapter< + absl::StatusOr, StringValue, StringValue, + StringValue>::RegisterGlobalOverload("regex.replace", &ReplaceAll, + registry))); + CEL_RETURN_IF_ERROR( + (QuaternaryFunctionAdapter< + absl::StatusOr, StringValue, StringValue, StringValue, + int64_t>::RegisterGlobalOverload("regex.replace", &ReplaceN, + registry))); + return absl::OkStatus(); +} + +absl::Status RegisterRegexExtensionFunctions(FunctionRegistry& registry, + const RuntimeOptions& options) { + if (options.enable_regex) { + CEL_RETURN_IF_ERROR(RegisterRegexExtensionFunctions(registry)); + } + return absl::OkStatus(); +} + +absl::Status RegisterRegexExtensionFunctions( + google::api::expr::runtime::CelFunctionRegistry* registry, + const google::api::expr::runtime::InterpreterOptions& options) { + return RegisterRegexExtensionFunctions( + registry->InternalGetRegistry(), + google::api::expr::runtime::ConvertToRuntimeOptions(options)); +} + +} // namespace cel::extensions diff --git a/extensions/regex_ext.h b/extensions/regex_ext.h new file mode 100644 index 000000000..29018779b --- /dev/null +++ b/extensions/regex_ext.h @@ -0,0 +1,94 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This extension depends on the CEL optional type. Please ensure that the +// EnableOptionalTypes is called when using regex extensions. +// +// # Replace +// +// The `regex.replace` function replaces all non-overlapping substring of a +// regex pattern in the target string with the given replacement string. +// Optionally, you can limit the number of replacements by providing a count +// argument. When the count is a negative number, the function acts as replace +// all. Only numeric (\N) capture group references are supported in the +// replacement string, with validation for correctness. Backslashed-escaped +// digits (\1 to \9) within the replacement argument can be used to insert text +// matching the corresponding parenthesized group in the regexp pattern. An +// error will be thrown for invalid regex or replace string. +// +// regex.replace(target: string, pattern: string, +// replacement: string) -> string +// regex.replace(target: string, pattern: string, +// replacement: string, count: int) -> string +// +// Examples: +// +// regex.replace('hello world hello', 'hello', 'hi') == 'hi world hi' +// regex.replace('banana', 'a', 'x', 0) == 'banana' +// regex.replace('banana', 'a', 'x', 1) == 'bxnana' +// regex.replace('banana', 'a', 'x', -12) == 'bxnxnx' +// regex.replace('foo bar', '(fo)o (ba)r', r'\2 \1') == 'ba fo' +// regex.replace('test', '(.)', r'\2') \\ Runtime Error invalid replace +// string regex.replace('foo bar', '(', '$2 $1') \\ Runtime Error invalid +// +// # Extract +// +// The `regex.extract` function returns the first match of a regex pattern in a +// string. If no match is found, it returns an optional none value. An error +// will be thrown for invalid regex or for multiple capture groups. +// +// regex.extract(target: string, pattern: string) -> optional +// +// Examples: +// +// regex.extract('item-A, item-B', 'item-(\\w+)') == optional.of('A') +// regex.extract('HELLO', 'hello') == optional.empty() +// regex.extract('testuser@testdomain', '(.*)@([^.]*)') // Runtime Error +// multiple capture group +// +// # Extract All +// +// The `regex.extractAll` function returns a list of all matches of a regex +// pattern in a target string. If no matches are found, it returns an empty +// list. An error will be thrown for invalid regex or for multiple capture +// groups. +// +// regex.extractAll(target: string, pattern: string) -> list +// +// Examples: +// +// regex.extractAll('id:123, id:456', 'id:\\d+') == ['id:123', 'id:456'] +// regex.extractAll('testuser@testdomain', '(.*)@([^.]*)') // Runtime Error +// multiple capture group + +#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_REGEX_EXT_H_ +#define THIRD_PARTY_CEL_CPP_EXTENSIONS_REGEX_EXT_H_ + +#include "absl/status/status.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_options.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" + +namespace cel::extensions { + +// Register extension functions for regular expressions. +absl::Status RegisterRegexExtensionFunctions( + google::api::expr::runtime::CelFunctionRegistry* registry, + const google::api::expr::runtime::InterpreterOptions& options); +absl::Status RegisterRegexExtensionFunctions(FunctionRegistry& registry, + const RuntimeOptions& options); + +} // namespace cel::extensions +#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_REGEX_EXT_H_ diff --git a/extensions/regex_ext_test.cc b/extensions/regex_ext_test.cc new file mode 100644 index 000000000..c626045ea --- /dev/null +++ b/extensions/regex_ext_test.cc @@ -0,0 +1,332 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/regex_ext.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "common/value.h" +#include "common/value_testing.h" +#include "extensions/protobuf/runtime_adapter.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "parser/parser.h" +#include "runtime/activation.h" +#include "runtime/optional_types.h" +#include "runtime/reference_resolver.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" +#include "runtime/standard_runtime_builder_factory.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/extension_set.h" + +namespace cel::extensions { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::IsOkAndHolds; +using ::absl_testing::StatusIs; +using ::cel::test::ErrorValueIs; +using ::google::api::expr::parser::Parse; +using test::BoolValueIs; +using test::OptionalValueIs; +using test::OptionalValueIsEmpty; +using test::StringValueIs; +using ::testing::HasSubstr; +using ::testing::TestWithParam; +using ::testing::ValuesIn; + +enum class EvaluationType { + kBoolTrue, + kOptionalValue, + kOptionalNone, + kRuntimeError, + kUnknownStaticError, + kInvalidArgStaticError +}; + +struct RegexExtTestCase { + EvaluationType evaluation_type; + std::string expr; + std::string expected_result = ""; +}; + +class RegexExtTest : public TestWithParam { + public: + void SetUp() override { + RuntimeOptions options; + options.enable_regex = true; + options.enable_qualified_type_identifiers = true; + + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + internal::GetTestingDescriptorPool(), options)); + ASSERT_THAT( + EnableReferenceResolver(builder, ReferenceResolverEnabled::kAlways), + IsOk()); + ASSERT_THAT(EnableOptionalTypes(builder), IsOk()); + ASSERT_THAT( + RegisterRegexExtensionFunctions(builder.function_registry(), options), + IsOk()); + ASSERT_OK_AND_ASSIGN(runtime_, std::move(builder).Build()); + } + + absl::StatusOr TestEvaluate(const std::string& expr_string) { + CEL_ASSIGN_OR_RETURN(auto parsed_expr, Parse(expr_string)); + CEL_ASSIGN_OR_RETURN(std::unique_ptr program, + cel::extensions::ProtobufRuntimeAdapter::CreateProgram( + *runtime_, parsed_expr)); + Activation activation; + return program->Evaluate(&arena_, activation); + } + + google::protobuf::Arena arena_; + std::unique_ptr runtime_; +}; + +std::vector regexTestCases() { + return { + // Tests for extract Function + {EvaluationType::kOptionalValue, + R"(regex.extract('hello world', 'hello (.*)'))", "world"}, + {EvaluationType::kOptionalValue, + R"(regex.extract('item-A, item-B', r'item-(\w+)'))", "A"}, + {EvaluationType::kOptionalValue, + R"(regex.extract('The color is red', r'The color is (\w+)'))", "red"}, + {EvaluationType::kOptionalValue, + R"(regex.extract('The color is red', r'The color is \w+'))", + "The color is red"}, + {EvaluationType::kOptionalValue, "regex.extract('brand', 'brand')", + "brand"}, + {EvaluationType::kOptionalNone, + "regex.extract('hello world', 'goodbye (.*)')"}, + {EvaluationType::kOptionalNone, "regex.extract('HELLO', 'hello')"}, + {EvaluationType::kOptionalNone, R"(regex.extract('', r'\w+'))"}, + + // Tests for extractAll Function + {EvaluationType::kBoolTrue, + "regex.extractAll('id:123, id:456', 'assa') == []"}, + {EvaluationType::kBoolTrue, + R"(regex.extractAll('id:123, id:456', r'id:\d+') == ['id:123','id:456'])"}, + {EvaluationType::kBoolTrue, + R"(regex.extractAll('Files: f_1.txt, f_2.csv', r'f_(\d+)')==['1','2'])"}, + {EvaluationType::kBoolTrue, + R"(regex.extractAll('testuser@', '(?P.*)@') == ['testuser'])"}, + {EvaluationType::kBoolTrue, + R"cel(regex.extractAll('t@gmail.com, a@y.com, 22@sdad.com', + '(?P.*)@') == ['t@gmail.com, a@y.com, 22'])cel"}, + {EvaluationType::kBoolTrue, + R"cel(regex.extractAll('t@gmail.com, a@y.com, 22@sdad.com', + r'(?P\w+)@') == ['t','a', '22'])cel"}, + {EvaluationType::kBoolTrue, + "regex.extractAll('banananana', '(ana)') == ['ana', 'ana']"}, + {EvaluationType::kBoolTrue, + R"(regex.extractAll('item:a1, topic:b2', + r'(?:item:|topic:)([a-z]\d)') == ['a1', 'b2'])"}, + {EvaluationType::kBoolTrue, + R"(regex.extractAll('val=a, val=, val=c', 'val=([^,]*)')==['a','c'])"}, + {EvaluationType::kBoolTrue, + "regex.extractAll('key=, key=, key=', 'key=([^,]*)') == []"}, + {EvaluationType::kBoolTrue, + R"(regex.extractAll('a b c', r'(\S*)\s*') == ['a', 'b', 'c'])"}, + {EvaluationType::kBoolTrue, + "regex.extractAll('abc', 'a|b*') == ['a','b']"}, + {EvaluationType::kBoolTrue, + "regex.extractAll('abc', 'a|(b)|c*') == ['b']"}, + + // Tests for replace Function + {EvaluationType::kBoolTrue, + "regex.replace('abc', '$', '_end') == 'abc_end'"}, + {EvaluationType::kBoolTrue, + R"(regex.replace('a-b', r'\b', '|') == '|a|-|b|')"}, + {EvaluationType::kBoolTrue, + R"(regex.replace('foo bar', '(fo)o (ba)r', r'\2 \1') == 'ba fo')"}, + {EvaluationType::kBoolTrue, + R"(regex.replace('foo bar', 'foo', r'\\') == '\\ bar')"}, + {EvaluationType::kBoolTrue, + "regex.replace('banana', 'ana', 'x') == 'bxna'"}, + {EvaluationType::kBoolTrue, + R"(regex.replace('abc', 'b(.)', r'x\1') == 'axc')"}, + {EvaluationType::kBoolTrue, + "regex.replace('hello world hello', 'hello', 'hi') == 'hi world hi'"}, + {EvaluationType::kBoolTrue, + R"(regex.replace('ac', 'a(b)?c', r'[\1]') == '[]')"}, + {EvaluationType::kBoolTrue, + "regex.replace('apple pie', 'p', 'X') == 'aXXle Xie'"}, + {EvaluationType::kBoolTrue, + R"(regex.replace('remove all spaces', r'\s', '') == + 'removeallspaces')"}, + {EvaluationType::kBoolTrue, + R"(regex.replace('digit:99919291992', r'\d+', '3') == 'digit:3')"}, + {EvaluationType::kBoolTrue, + R"cel(regex.replace('foo bar baz', r'\w+', r'(\0)') == + '(foo) (bar) (baz)')cel"}, + {EvaluationType::kBoolTrue, "regex.replace('', 'a', 'b') == ''"}, + {EvaluationType::kBoolTrue, + R"cel(regex.replace('User: Alice, Age: 30', + r'User: (?P\w+), Age: (?P\d+)', + '${name} is ${age} years old') == '${name} is ${age} years old')cel"}, + {EvaluationType::kBoolTrue, + R"cel(regex.replace('User: Alice, Age: 30', + r'User: (?P\w+), Age: (?P\d+)', r'\1 is \2 years old') == + 'Alice is 30 years old')cel"}, + {EvaluationType::kBoolTrue, + "regex.replace('hello ☃', '☃', '❄') == 'hello ❄'"}, + {EvaluationType::kBoolTrue, + R"(regex.replace('id=123', r'id=(?P\d+)', r'value: \1') == + 'value: 123')"}, + {EvaluationType::kBoolTrue, + "regex.replace('banana', 'a', 'x') == 'bxnxnx'"}, + {EvaluationType::kBoolTrue, + R"(regex.replace(regex.replace('%(foo) %(bar) %2', r'%\((\w+)\)', + r'${\1}'),r'%(\d+)', r'$\1') == '${foo} ${bar} $2')"}, + {EvaluationType::kBoolTrue, + R"(regex.replace('abc def', r'(abc)', r'\\1') == r'\1 def')"}, + {EvaluationType::kBoolTrue, + R"(regex.replace('abc def', r'(abc)', r'\\2') == r'\2 def')"}, + {EvaluationType::kBoolTrue, + R"(regex.replace('abc def', r'(abc)', r'\\{word}') == '\\{word} def')"}, + {EvaluationType::kBoolTrue, + R"(regex.replace('abc def', r'(abc)', r'\\word') == '\\word def')"}, + {EvaluationType::kBoolTrue, + "regex.replace('abc', '^', 'start_') == 'start_abc'"}, + + // Tests for replace Function with count variable + {EvaluationType::kBoolTrue, + R"(regex.replace('foofoo', 'foo', 'bar', + 9223372036854775807) == 'barbar')"}, + {EvaluationType::kBoolTrue, + "regex.replace('banana', 'a', 'x', 0) == 'banana'"}, + {EvaluationType::kBoolTrue, + "regex.replace('banana', 'a', 'x', 1) == 'bxnana'"}, + {EvaluationType::kBoolTrue, + "regex.replace('banana', 'a', 'x', 2) == 'bxnxna'"}, + {EvaluationType::kBoolTrue, + "regex.replace('banana', 'a', 'x', 100) == 'bxnxnx'"}, + {EvaluationType::kBoolTrue, + "regex.replace('banana', 'a', 'x', -1) == 'bxnxnx'"}, + {EvaluationType::kBoolTrue, + "regex.replace('banana', 'a', 'x', -100) == 'bxnxnx'"}, + {EvaluationType::kBoolTrue, + R"cel(regex.replace('cat-dog dog-cat cat-dog dog-cat', '(cat)-(dog)', + r'\2-\1', 1) == 'dog-cat dog-cat cat-dog dog-cat')cel"}, + {EvaluationType::kBoolTrue, + R"cel(regex.replace('cat-dog dog-cat cat-dog dog-cat', '(cat)-(dog)', + r'\2-\1', 2) == 'dog-cat dog-cat dog-cat dog-cat')cel"}, + {EvaluationType::kBoolTrue, + R"(regex.replace('a.b.c', r'\.', '-', 1) == 'a-b.c')"}, + {EvaluationType::kBoolTrue, + R"(regex.replace('a.b.c', r'\.', '-', -1) == 'a-b-c')"}, + {EvaluationType::kBoolTrue, + R"(regex.replace('123456789ABC', + '(\\d)(\\d)(\\d)(\\d)(\\d)(\\d)(\\d)(\\d)(\\d)(\\w)(\\w)(\\w)','X', 1) + == 'X')"}, + {EvaluationType::kBoolTrue, + R"(regex.replace('123456789ABC', + '(\\d)(\\d)(\\d)(\\d)(\\d)(\\d)(\\d)(\\d)(\\d)(\\w)(\\w)(\\w)', + r'\1-\9-X', 1) == '1-9-X')"}, + + // Static Errors + {EvaluationType::kUnknownStaticError, "regex.replace('abc', '^', 1)", + "No matching overloads found : regex.replace(string, string, int64)"}, + {EvaluationType::kUnknownStaticError, "regex.replace('abc', '^', '1','')", + "No matching overloads found : regex.replace(string, string, string, " + "string)"}, + {EvaluationType::kUnknownStaticError, "regex.extract('foo bar', 1)", + "No matching overloads found : regex.extract(string, int64)"}, + {EvaluationType::kInvalidArgStaticError, + "regex.extract('foo bar', 1, 'bar')", + "No overload found in reference resolve step for extract"}, + {EvaluationType::kInvalidArgStaticError, "regex.extractAll()", + "No overload found in reference resolve step for extractAll"}, + + // Runtime Errors + {EvaluationType::kRuntimeError, R"(regex.extract('foo', 'fo(o+)(abc'))", + "given regex is invalid: missing ): fo(o+)(abc"}, + {EvaluationType::kRuntimeError, R"(regex.extractAll('foo bar', '[a-z'))", + "given regex is invalid: missing ]: [a-z"}, + {EvaluationType::kRuntimeError, + R"(regex.replace('foo bar', '[a-z', 'a'))", + "given regex is invalid: missing ]: [a-z"}, + {EvaluationType::kRuntimeError, + R"(regex.replace('foo bar', '[a-z', 'a', 1))", + "given regex is invalid: missing ]: [a-z"}, + {EvaluationType::kRuntimeError, + R"(regex.replace('id=123', r'id=(?P\d+)', r'value: \values'))", + R"(invalid replacement string: Rewrite schema error: '\' must be followed by a digit or '\'.)"}, + {EvaluationType::kRuntimeError, R"(regex.replace('test', '(t)', '\\2'))", + "invalid replacement string: Rewrite schema requests 2 matches, but " + "the regexp only has 1 parenthesized subexpressions"}, + {EvaluationType::kRuntimeError, + R"(regex.replace('id=123', r'id=(?P\d+)', '\\', 1))", + R"(invalid replacement string: Rewrite schema error: '\' not allowed at end.)"}, + {EvaluationType::kRuntimeError, + R"(regex.extract('phone: 415-5551212', r'phone: ((\d{3})-)?'))", + R"(regular expression has more than one capturing group: phone: ((\d{3})-)?)"}, + {EvaluationType::kRuntimeError, + R"(regex.extractAll('testuser@testdomain', '(.*)@([^.]*)'))", + R"(regular expression has more than one capturing group: (.*)@([^.]*))"}, + }; +} + +TEST_P(RegexExtTest, RegexExtTests) { + const RegexExtTestCase& test_case = GetParam(); + auto result = TestEvaluate(test_case.expr); + + switch (test_case.evaluation_type) { + case EvaluationType::kRuntimeError: + EXPECT_THAT(result, IsOkAndHolds(ErrorValueIs( + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr(test_case.expected_result))))) + << "Expression: " << test_case.expr; + break; + case EvaluationType::kUnknownStaticError: + EXPECT_THAT(result, IsOkAndHolds(ErrorValueIs( + StatusIs(absl::StatusCode::kUnknown, + HasSubstr(test_case.expected_result))))) + << "Expression: " << test_case.expr; + break; + case EvaluationType::kInvalidArgStaticError: + EXPECT_THAT(result, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr(test_case.expected_result))) + << "Expression: " << test_case.expr; + break; + case EvaluationType::kOptionalNone: + EXPECT_THAT(result, IsOkAndHolds(OptionalValueIsEmpty())) + << "Expression: " << test_case.expr; + break; + case EvaluationType::kOptionalValue: + EXPECT_THAT(result, IsOkAndHolds(OptionalValueIs( + StringValueIs(test_case.expected_result)))) + << "Expression: " << test_case.expr; + break; + case EvaluationType::kBoolTrue: + EXPECT_THAT(result, IsOkAndHolds(BoolValueIs(true))) + << "Expression: " << test_case.expr; + break; + } +} + +INSTANTIATE_TEST_SUITE_P(RegexExtTest, RegexExtTest, + ValuesIn(regexTestCases())); +} // namespace +} // namespace cel::extensions From 92adcfb94ade7a31eeb9c5595a1609df2f4a1476 Mon Sep 17 00:00:00 2001 From: Muhammad Askri Date: Thu, 17 Jul 2025 11:59:10 -0700 Subject: [PATCH 4/4] Adding the decls and checker lib to CEL regex extensions PiperOrigin-RevId: 784263673 --- eval/compiler/flat_expr_builder.h | 2 + extensions/BUILD | 17 +++++- extensions/regex_ext.cc | 90 +++++++++++++++++++++++++++---- extensions/regex_ext.h | 31 +++++++++-- extensions/regex_ext_test.cc | 84 +++++++++++++++++++++++++++-- 5 files changed, 203 insertions(+), 21 deletions(-) diff --git a/eval/compiler/flat_expr_builder.h b/eval/compiler/flat_expr_builder.h index 758865769..50c0bd9b0 100644 --- a/eval/compiler/flat_expr_builder.h +++ b/eval/compiler/flat_expr_builder.h @@ -93,6 +93,8 @@ class FlatExprBuilder { // `optional_type` handling is needed. void enable_optional_types() { enable_optional_types_ = true; } + bool optional_types_enabled() const { return enable_optional_types_; } + private: const cel::TypeProvider& GetTypeProvider() const; diff --git a/extensions/BUILD b/extensions/BUILD index f127e1eed..c448f5366 100644 --- a/extensions/BUILD +++ b/extensions/BUILD @@ -666,13 +666,22 @@ cc_library( srcs = ["regex_ext.cc"], hdrs = ["regex_ext.h"], deps = [ + "//checker:type_checker_builder", + "//checker/internal:builtins_arena", + "//common:decl", + "//common:type", "//common:value", + "//compiler", "//eval/public:cel_function_registry", "//eval/public:cel_options", + "//internal:casts", "//internal:status_macros", "//runtime:function_adapter", "//runtime:function_registry", - "//runtime:runtime_options", + "//runtime:runtime_builder", + "//runtime/internal:runtime_friend_access", + "//runtime/internal:runtime_impl", + "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -688,8 +697,12 @@ cc_test( srcs = ["regex_ext_test.cc"], deps = [ ":regex_ext", + "//checker:standard_library", + "//checker:validation_result", "//common:value", "//common:value_testing", + "//compiler", + "//compiler:compiler_factory", "//extensions/protobuf:runtime_adapter", "//internal:status_macros", "//internal:testing", @@ -699,12 +712,12 @@ cc_test( "//runtime:activation", "//runtime:optional_types", "//runtime:reference_resolver", - "//runtime:runtime_builder", "//runtime:runtime_options", "//runtime:standard_runtime_builder_factory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@com_google_protobuf//:protobuf", ], ) diff --git a/extensions/regex_ext.cc b/extensions/regex_ext.cc index 54cb3e24d..c2766c2c2 100644 --- a/extensions/regex_ext.cc +++ b/extensions/regex_ext.cc @@ -16,22 +16,30 @@ #include #include -#include #include #include +#include "absl/base/no_destructor.h" #include "absl/base/nullability.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" +#include "checker/internal/builtins_arena.h" +#include "checker/type_checker_builder.h" +#include "common/decl.h" +#include "common/type.h" #include "common/value.h" +#include "compiler/compiler.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" +#include "internal/casts.h" #include "internal/status_macros.h" #include "runtime/function_adapter.h" #include "runtime/function_registry.h" -#include "runtime/runtime_options.h" +#include "runtime/internal/runtime_friend_access.h" +#include "runtime/internal/runtime_impl.h" +#include "runtime/runtime_builder.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" @@ -40,6 +48,8 @@ namespace cel::extensions { namespace { +using ::cel::checker_internal::BuiltinsArena; + Value Extract(const StringValue& target, const StringValue& regex, const google::protobuf::DescriptorPool* ABSL_NONNULL descriptor_pool, google::protobuf::MessageFactory* ABSL_NONNULL message_factory, @@ -222,8 +232,6 @@ Value ReplaceN(const StringValue& target, const StringValue& regex, return StringValue::From(std::move(output), arena); } -} // namespace - absl::Status RegisterRegexExtensionFunctions(FunctionRegistry& registry) { CEL_RETURN_IF_ERROR( (BinaryFunctionAdapter, StringValue, StringValue>:: @@ -244,10 +252,61 @@ absl::Status RegisterRegexExtensionFunctions(FunctionRegistry& registry) { return absl::OkStatus(); } -absl::Status RegisterRegexExtensionFunctions(FunctionRegistry& registry, - const RuntimeOptions& options) { - if (options.enable_regex) { - CEL_RETURN_IF_ERROR(RegisterRegexExtensionFunctions(registry)); +const Type& OptionalStringType() { + static absl::NoDestructor kInstance( + OptionalType(BuiltinsArena(), StringType())); + return *kInstance; +} + +const Type& ListStringType() { + static absl::NoDestructor kInstance( + ListType(BuiltinsArena(), StringType())); + return *kInstance; +} + +absl::Status RegisterRegexCheckerDecls(TypeCheckerBuilder& builder) { + CEL_ASSIGN_OR_RETURN( + FunctionDecl extract_decl, + MakeFunctionDecl( + "regex.extract", + MakeOverloadDecl("regex_extract_string_string", OptionalStringType(), + StringType(), StringType()))); + + CEL_ASSIGN_OR_RETURN( + FunctionDecl extract_all_decl, + MakeFunctionDecl( + "regex.extractAll", + MakeOverloadDecl("regex_extractAll_string_string", ListStringType(), + StringType(), StringType()))); + + CEL_ASSIGN_OR_RETURN( + FunctionDecl replace_decl, + MakeFunctionDecl( + "regex.replace", + MakeOverloadDecl("regex_replace_string_string_string", StringType(), + StringType(), StringType(), StringType()), + MakeOverloadDecl("regex_replace_string_string_string_int", + StringType(), StringType(), StringType(), + StringType(), IntType()))); + + CEL_RETURN_IF_ERROR(builder.AddFunction(extract_decl)); + CEL_RETURN_IF_ERROR(builder.AddFunction(extract_all_decl)); + CEL_RETURN_IF_ERROR(builder.AddFunction(replace_decl)); + return absl::OkStatus(); +} + +} // namespace + +absl::Status RegisterRegexExtensionFunctions(RuntimeBuilder& builder) { + auto& runtime = cel::internal::down_cast( + runtime_internal::RuntimeFriendAccess::GetMutableRuntime(builder)); + if (!runtime.expr_builder().optional_types_enabled()) { + return absl::InvalidArgumentError( + "regex extensions requires the optional types to be enabled"); + } + if (runtime.expr_builder().options().enable_regex) { + CEL_RETURN_IF_ERROR( + RegisterRegexExtensionFunctions(builder.function_registry())); } return absl::OkStatus(); } @@ -255,9 +314,18 @@ absl::Status RegisterRegexExtensionFunctions(FunctionRegistry& registry, absl::Status RegisterRegexExtensionFunctions( google::api::expr::runtime::CelFunctionRegistry* registry, const google::api::expr::runtime::InterpreterOptions& options) { - return RegisterRegexExtensionFunctions( - registry->InternalGetRegistry(), - google::api::expr::runtime::ConvertToRuntimeOptions(options)); + if (!options.enable_regex) { + return RegisterRegexExtensionFunctions(registry->InternalGetRegistry()); + } + return absl::OkStatus(); +} + +CheckerLibrary RegexExtCheckerLibrary() { + return {.id = "cel.lib.ext.regex", .configure = RegisterRegexCheckerDecls}; +} + +CompilerLibrary RegexExtCompilerLibrary() { + return CompilerLibrary::FromCheckerLibrary(RegexExtCheckerLibrary()); } } // namespace cel::extensions diff --git a/extensions/regex_ext.h b/extensions/regex_ext.h index 29018779b..b5da5c588 100644 --- a/extensions/regex_ext.h +++ b/extensions/regex_ext.h @@ -76,10 +76,11 @@ #define THIRD_PARTY_CEL_CPP_EXTENSIONS_REGEX_EXT_H_ #include "absl/status/status.h" +#include "checker/type_checker_builder.h" +#include "compiler/compiler.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" -#include "runtime/function_registry.h" -#include "runtime/runtime_options.h" +#include "runtime/runtime_builder.h" namespace cel::extensions { @@ -87,8 +88,30 @@ namespace cel::extensions { absl::Status RegisterRegexExtensionFunctions( google::api::expr::runtime::CelFunctionRegistry* registry, const google::api::expr::runtime::InterpreterOptions& options); -absl::Status RegisterRegexExtensionFunctions(FunctionRegistry& registry, - const RuntimeOptions& options); +absl::Status RegisterRegexExtensionFunctions(RuntimeBuilder& builder); + +// Type check declarations for the regex extension library. +// Provides decls for the following functions: +// +// regex.replace(target: str, pattern: str, replacement: str) -> str +// +// regex.replace(target: str, pattern: str, replacement: str, count: int) -> str +// +// regex.extract(target: str, pattern: str) -> optional +// +// regex.extractAll(target: str, pattern: str) -> list +CheckerLibrary RegexExtCheckerLibrary(); + +// Provides decls for the following functions: +// +// regex.replace(target: str, pattern: str, replacement: str) -> str +// +// regex.replace(target: str, pattern: str, replacement: str, count: int) -> str +// +// regex.extract(target: str, pattern: str) -> optional +// +// regex.extractAll(target: str, pattern: str) -> list +CompilerLibrary RegexExtCompilerLibrary(); } // namespace cel::extensions #endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_REGEX_EXT_H_ diff --git a/extensions/regex_ext_test.cc b/extensions/regex_ext_test.cc index c626045ea..42971e880 100644 --- a/extensions/regex_ext_test.cc +++ b/extensions/regex_ext_test.cc @@ -22,8 +22,13 @@ #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "checker/standard_library.h" +#include "checker/validation_result.h" #include "common/value.h" #include "common/value_testing.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" #include "extensions/protobuf/runtime_adapter.h" #include "internal/status_macros.h" #include "internal/testing.h" @@ -33,7 +38,6 @@ #include "runtime/optional_types.h" #include "runtime/reference_resolver.h" #include "runtime/runtime.h" -#include "runtime/runtime_builder.h" #include "runtime/runtime_options.h" #include "runtime/standard_runtime_builder_factory.h" #include "google/protobuf/arena.h" @@ -84,9 +88,7 @@ class RegexExtTest : public TestWithParam { EnableReferenceResolver(builder, ReferenceResolverEnabled::kAlways), IsOk()); ASSERT_THAT(EnableOptionalTypes(builder), IsOk()); - ASSERT_THAT( - RegisterRegexExtensionFunctions(builder.function_registry(), options), - IsOk()); + ASSERT_THAT(RegisterRegexExtensionFunctions(builder), IsOk()); ASSERT_OK_AND_ASSIGN(runtime_, std::move(builder).Build()); } @@ -103,6 +105,23 @@ class RegexExtTest : public TestWithParam { std::unique_ptr runtime_; }; +TEST_F(RegexExtTest, BuildFailsWithoutOptionalSupport) { + RuntimeOptions options; + options.enable_regex = true; + options.enable_qualified_type_identifiers = true; + + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + internal::GetTestingDescriptorPool(), options)); + ASSERT_THAT( + EnableReferenceResolver(builder, ReferenceResolverEnabled::kAlways), + IsOk()); + // Optional types are NOT enabled. + ASSERT_THAT(RegisterRegexExtensionFunctions(builder), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("regex extensions requires the optional types " + "to be enabled"))); +} std::vector regexTestCases() { return { // Tests for extract Function @@ -121,6 +140,11 @@ std::vector regexTestCases() { "regex.extract('hello world', 'goodbye (.*)')"}, {EvaluationType::kOptionalNone, "regex.extract('HELLO', 'hello')"}, {EvaluationType::kOptionalNone, R"(regex.extract('', r'\w+'))"}, + {EvaluationType::kBoolTrue, + "regex.extract('4122345432', '22').orValue('777') == '22'"}, + {EvaluationType::kBoolTrue, + "regex.extract('4122345432', '22').or(optional.of('777')) == " + "optional.of('22')"}, // Tests for extractAll Function {EvaluationType::kBoolTrue, @@ -328,5 +352,57 @@ TEST_P(RegexExtTest, RegexExtTests) { INSTANTIATE_TEST_SUITE_P(RegexExtTest, RegexExtTest, ValuesIn(regexTestCases())); + +struct RegexCheckerTestCase { + std::string expr_string; + std::string error_substr; +}; + +class RegexExtCheckerLibraryTest : public TestWithParam { + public: + void SetUp() override { + // Arrange: Configure the compiler. + // Add the regex checker library to the compiler builder. + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler_builder, + NewCompilerBuilder(descriptor_pool_)); + ASSERT_THAT(compiler_builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + ASSERT_THAT(compiler_builder->AddLibrary(RegexExtCompilerLibrary()), + IsOk()); + ASSERT_OK_AND_ASSIGN(compiler_, std::move(*compiler_builder).Build()); + } + + const google::protobuf::DescriptorPool* descriptor_pool_ = + internal::GetTestingDescriptorPool(); + std::unique_ptr compiler_; +}; + +TEST_P(RegexExtCheckerLibraryTest, RegexExtTypeCheckerTests) { + // Act & Assert: Compile the expression and validate the result. + ASSERT_OK_AND_ASSIGN(ValidationResult result, + compiler_->Compile(GetParam().expr_string)); + absl::string_view error_substr = GetParam().error_substr; + EXPECT_EQ(result.IsValid(), error_substr.empty()); + + if (!error_substr.empty()) { + EXPECT_THAT(result.FormatError(), HasSubstr(error_substr)); + } +} + +std::vector createRegexCheckerParams() { + return { + {"regex.replace('abc', 'a', 's') == 'sbc'"}, + {"regex.replace('abc', 'a', 's') == 121", + "found no matching overload for '_==_' applied to '(string, int)"}, + {"regex.replace('abc', 'j', '1', 2) == 9.0", + "found no matching overload for '_==_' applied to '(string, double)"}, + {"regex.extractAll('banananana', '(ana)') == ['ana', 'ana']"}, + {"regex.extract('foo bar', 'f') == 121", + "found no matching overload for '_==_' applied to " + "'(optional_type(string), int)'"}, + }; +} + +INSTANTIATE_TEST_SUITE_P(RegexExtCheckerLibraryTest, RegexExtCheckerLibraryTest, + ValuesIn(createRegexCheckerParams())); } // namespace } // namespace cel::extensions