Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 5580c78

Browse files
maskri17 authored and copybara-github committed
Adding the decls and checker lib to CEL regex extensions
PiperOrigin-RevId: 783106089
1 parent 2c59cf3 commit 5580c78

File tree

5 files changed

+203
-21
lines changed

5 files changed

+203
-21
lines changed

eval/compiler/flat_expr_builder.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ class FlatExprBuilder {
9393
// `optional_type` handling is needed.
9494
void enable_optional_types() { enable_optional_types_ = true; }
9595

96+
// Reports whether optional type handling has been switched on for this
// builder, i.e. the current value of the flag that `enable_optional_types()`
// sets to true.
bool optional_types_enabled() const { return enable_optional_types_; }
97+
9698
private:
9799
const cel::TypeProvider& GetTypeProvider() const;
98100

extensions/BUILD

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -666,13 +666,22 @@ cc_library(
666666
srcs = ["regex_ext.cc"],
667667
hdrs = ["regex_ext.h"],
668668
deps = [
669+
"//checker:type_checker_builder",
670+
"//checker/internal:builtins_arena",
671+
"//common:decl",
672+
"//common:type",
669673
"//common:value",
674+
"//compiler",
670675
"//eval/public:cel_function_registry",
671676
"//eval/public:cel_options",
677+
"//internal:casts",
672678
"//internal:status_macros",
673679
"//runtime:function_adapter",
674680
"//runtime:function_registry",
675-
"//runtime:runtime_options",
681+
"//runtime:runtime_builder",
682+
"//runtime/internal:runtime_friend_access",
683+
"//runtime/internal:runtime_impl",
684+
"@com_google_absl//absl/base:no_destructor",
676685
"@com_google_absl//absl/base:nullability",
677686
"@com_google_absl//absl/status",
678687
"@com_google_absl//absl/status:statusor",
@@ -688,8 +697,12 @@ cc_test(
688697
srcs = ["regex_ext_test.cc"],
689698
deps = [
690699
":regex_ext",
700+
"//checker:standard_library",
701+
"//checker:validation_result",
691702
"//common:value",
692703
"//common:value_testing",
704+
"//compiler",
705+
"//compiler:compiler_factory",
693706
"//extensions/protobuf:runtime_adapter",
694707
"//internal:status_macros",
695708
"//internal:testing",
@@ -699,12 +712,12 @@ cc_test(
699712
"//runtime:activation",
700713
"//runtime:optional_types",
701714
"//runtime:reference_resolver",
702-
"//runtime:runtime_builder",
703715
"//runtime:runtime_options",
704716
"//runtime:standard_runtime_builder_factory",
705717
"@com_google_absl//absl/status",
706718
"@com_google_absl//absl/status:status_matchers",
707719
"@com_google_absl//absl/status:statusor",
720+
"@com_google_absl//absl/strings:string_view",
708721
"@com_google_protobuf//:protobuf",
709722
],
710723
)

extensions/regex_ext.cc

Lines changed: 79 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,22 +16,30 @@
1616

1717
#include <algorithm>
1818
#include <cstdint>
19-
#include <memory>
2019
#include <string>
2120
#include <utility>
2221

22+
#include "absl/base/no_destructor.h"
2323
#include "absl/base/nullability.h"
2424
#include "absl/status/status.h"
2525
#include "absl/status/statusor.h"
2626
#include "absl/strings/str_format.h"
2727
#include "absl/strings/string_view.h"
28+
#include "checker/internal/builtins_arena.h"
29+
#include "checker/type_checker_builder.h"
30+
#include "common/decl.h"
31+
#include "common/type.h"
2832
#include "common/value.h"
33+
#include "compiler/compiler.h"
2934
#include "eval/public/cel_function_registry.h"
3035
#include "eval/public/cel_options.h"
36+
#include "internal/casts.h"
3137
#include "internal/status_macros.h"
3238
#include "runtime/function_adapter.h"
3339
#include "runtime/function_registry.h"
34-
#include "runtime/runtime_options.h"
40+
#include "runtime/internal/runtime_friend_access.h"
41+
#include "runtime/internal/runtime_impl.h"
42+
#include "runtime/runtime_builder.h"
3543
#include "google/protobuf/arena.h"
3644
#include "google/protobuf/descriptor.h"
3745
#include "google/protobuf/message.h"
@@ -40,6 +48,8 @@
4048
namespace cel::extensions {
4149
namespace {
4250

51+
using ::cel::checker_internal::BuiltinsArena;
52+
4353
Value Extract(const StringValue& target, const StringValue& regex,
4454
const google::protobuf::DescriptorPool* ABSL_NONNULL descriptor_pool,
4555
google::protobuf::MessageFactory* ABSL_NONNULL message_factory,
@@ -222,8 +232,6 @@ Value ReplaceN(const StringValue& target, const StringValue& regex,
222232
return StringValue::From(std::move(output), arena);
223233
}
224234

225-
} // namespace
226-
227235
absl::Status RegisterRegexExtensionFunctions(FunctionRegistry& registry) {
228236
CEL_RETURN_IF_ERROR(
229237
(BinaryFunctionAdapter<absl::StatusOr<Value>, StringValue, StringValue>::
@@ -244,20 +252,80 @@ absl::Status RegisterRegexExtensionFunctions(FunctionRegistry& registry) {
244252
return absl::OkStatus();
245253
}
246254

247-
absl::Status RegisterRegexExtensionFunctions(FunctionRegistry& registry,
248-
const RuntimeOptions& options) {
249-
if (options.enable_regex) {
250-
CEL_RETURN_IF_ERROR(RegisterRegexExtensionFunctions(registry));
255+
// Returns the shared `optional_type(string)` checker type.
//
// The value is built once, on first call, from `BuiltinsArena()` and is held
// in an `absl::NoDestructor`; the same reference is handed out on every call.
const Type& OptionalStringType() {
  static absl::NoDestructor<Type> optional_string(
      OptionalType(BuiltinsArena(), StringType()));
  return *optional_string.get();
}
260+
261+
const Type& ListStringType() {
262+
static absl::NoDestructor<Type> kInstance(
263+
ListType(BuiltinsArena(), StringType()));
264+
return *kInstance;
265+
}
266+
267+
// Adds the type-check declarations for the regex extension to `builder`.
//
// Declared functions (first type of each overload is the result type):
//   regex.extract(string, string)             -> optional_type(string)
//   regex.extractAll(string, string)          -> list(string)
//   regex.replace(string, string, string)     -> string
//   regex.replace(string, string, string, int) -> string
//
// All declarations are constructed before any of them is added, so a failure
// while building a declaration leaves `builder` untouched. Returns the first
// error encountered, or OK.
absl::Status RegisterRegexCheckerDecls(TypeCheckerBuilder& builder) {
  CEL_ASSIGN_OR_RETURN(
      FunctionDecl extract_fn,
      MakeFunctionDecl("regex.extract",
                       MakeOverloadDecl("regex_extract_string_string",
                                        /*result=*/OptionalStringType(),
                                        StringType(), StringType())));

  CEL_ASSIGN_OR_RETURN(
      FunctionDecl extract_all_fn,
      MakeFunctionDecl("regex.extractAll",
                       MakeOverloadDecl("regex_extractAll_string_string",
                                        /*result=*/ListStringType(),
                                        StringType(), StringType())));

  CEL_ASSIGN_OR_RETURN(
      FunctionDecl replace_fn,
      MakeFunctionDecl(
          "regex.replace",
          MakeOverloadDecl("regex_replace_string_string_string",
                           /*result=*/StringType(), StringType(), StringType(),
                           StringType()),
          MakeOverloadDecl("regex_replace_string_string_string_int",
                           /*result=*/StringType(), StringType(), StringType(),
                           StringType(), IntType())));

  // Register in the same order the declarations were built.
  for (const FunctionDecl* decl : {&extract_fn, &extract_all_fn, &replace_fn}) {
    CEL_RETURN_IF_ERROR(builder.AddFunction(*decl));
  }
  return absl::OkStatus();
}
297+
298+
} // namespace
299+
300+
// Registers the regex extension functions on the runtime under construction.
//
// Fails with InvalidArgument when optional types have not been enabled on the
// builder's expression builder. When the runtime's `enable_regex` option is
// false this is a no-op that returns OK.
absl::Status RegisterRegexExtensionFunctions(RuntimeBuilder& builder) {
  auto& runtime_impl =
      cel::internal::down_cast<runtime_internal::RuntimeImpl&>(
          runtime_internal::RuntimeFriendAccess::GetMutableRuntime(builder));
  auto& expr_builder = runtime_impl.expr_builder();

  // The optional-types check runs first, regardless of `enable_regex`.
  if (!expr_builder.optional_types_enabled()) {
    return absl::InvalidArgumentError(
        "regex extensions requires the optional types to be enabled");
  }

  if (!expr_builder.options().enable_regex) {
    return absl::OkStatus();
  }

  CEL_RETURN_IF_ERROR(
      RegisterRegexExtensionFunctions(builder.function_registry()));
  return absl::OkStatus();
}
254313

255314
// Legacy entry point: registers the regex extension functions on a
// `CelFunctionRegistry`.
//
// The functions are registered only when `options.enable_regex` is true;
// otherwise nothing is registered and OK is returned. This matches the
// RuntimeBuilder overload and the behaviour of this function before it
// stopped going through `ConvertToRuntimeOptions` (the check had been
// accidentally negated, which registered the functions exactly when regex
// support was disabled).
//
// Unlike the RuntimeBuilder overload, this overload performs no check that
// optional types are enabled.
absl::Status RegisterRegexExtensionFunctions(
    google::api::expr::runtime::CelFunctionRegistry* registry,
    const google::api::expr::runtime::InterpreterOptions& options) {
  if (options.enable_regex) {
    return RegisterRegexExtensionFunctions(registry->InternalGetRegistry());
  }
  return absl::OkStatus();
}
322+
323+
// Builds the checker library for the regex extension: id "cel.lib.ext.regex",
// configured by `RegisterRegexCheckerDecls`.
CheckerLibrary RegexExtCheckerLibrary() {
  CheckerLibrary library;
  library.id = "cel.lib.ext.regex";
  library.configure = RegisterRegexCheckerDecls;
  return library;
}
326+
327+
// Wraps the regex checker library as a compiler library.
CompilerLibrary RegexExtCompilerLibrary() {
  CheckerLibrary checker_library = RegexExtCheckerLibrary();
  return CompilerLibrary::FromCheckerLibrary(std::move(checker_library));
}
262330

263331
} // namespace cel::extensions

extensions/regex_ext.h

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,19 +76,42 @@
7676
#define THIRD_PARTY_CEL_CPP_EXTENSIONS_REGEX_EXT_H_
7777

7878
#include "absl/status/status.h"
79+
#include "checker/type_checker_builder.h"
80+
#include "compiler/compiler.h"
7981
#include "eval/public/cel_function_registry.h"
8082
#include "eval/public/cel_options.h"
81-
#include "runtime/function_registry.h"
82-
#include "runtime/runtime_options.h"
83+
#include "runtime/runtime_builder.h"
8384

8485
namespace cel::extensions {
8586

8687
// Register extension functions for regular expressions.
8788
absl::Status RegisterRegexExtensionFunctions(
8889
google::api::expr::runtime::CelFunctionRegistry* registry,
8990
const google::api::expr::runtime::InterpreterOptions& options);
90-
absl::Status RegisterRegexExtensionFunctions(FunctionRegistry& registry,
91-
const RuntimeOptions& options);
91+
// Registers the regex extension functions on the runtime being built.
// Returns InvalidArgument if optional types have not been enabled on
// `builder`; registers nothing (and returns OK) when the runtime's
// `enable_regex` option is false.
absl::Status RegisterRegexExtensionFunctions(RuntimeBuilder& builder);
92+
93+
// Type check declarations for the regex extension library.
94+
// Provides decls for the following functions:
95+
//
96+
// regex.replace(target: str, pattern: str, replacement: str) -> str
97+
//
98+
// regex.replace(target: str, pattern: str, replacement: str, count: int) -> str
99+
//
100+
// regex.extract(target: str, pattern: str) -> optional<str>
101+
//
102+
// regex.extractAll(target: str, pattern: str) -> list<str>
103+
CheckerLibrary RegexExtCheckerLibrary();
104+
105+
// Compiler library for the regex extension, built from
// `RegexExtCheckerLibrary()`.
// Provides decls for the following functions:
106+
//
107+
// regex.replace(target: str, pattern: str, replacement: str) -> str
108+
//
109+
// regex.replace(target: str, pattern: str, replacement: str, count: int) -> str
110+
//
111+
// regex.extract(target: str, pattern: str) -> optional<str>
112+
//
113+
// regex.extractAll(target: str, pattern: str) -> list<str>
114+
CompilerLibrary RegexExtCompilerLibrary();
92115

93116
} // namespace cel::extensions
94117
#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_REGEX_EXT_H_

extensions/regex_ext_test.cc

Lines changed: 80 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,13 @@
2222
#include "absl/status/status.h"
2323
#include "absl/status/status_matchers.h"
2424
#include "absl/status/statusor.h"
25+
#include "absl/strings/string_view.h"
26+
#include "checker/standard_library.h"
27+
#include "checker/validation_result.h"
2528
#include "common/value.h"
2629
#include "common/value_testing.h"
30+
#include "compiler/compiler.h"
31+
#include "compiler/compiler_factory.h"
2732
#include "extensions/protobuf/runtime_adapter.h"
2833
#include "internal/status_macros.h"
2934
#include "internal/testing.h"
@@ -33,7 +38,6 @@
3338
#include "runtime/optional_types.h"
3439
#include "runtime/reference_resolver.h"
3540
#include "runtime/runtime.h"
36-
#include "runtime/runtime_builder.h"
3741
#include "runtime/runtime_options.h"
3842
#include "runtime/standard_runtime_builder_factory.h"
3943
#include "google/protobuf/arena.h"
@@ -84,9 +88,7 @@ class RegexExtTest : public TestWithParam<RegexExtTestCase> {
8488
EnableReferenceResolver(builder, ReferenceResolverEnabled::kAlways),
8589
IsOk());
8690
ASSERT_THAT(EnableOptionalTypes(builder), IsOk());
87-
ASSERT_THAT(
88-
RegisterRegexExtensionFunctions(builder.function_registry(), options),
89-
IsOk());
91+
ASSERT_THAT(RegisterRegexExtensionFunctions(builder), IsOk());
9092
ASSERT_OK_AND_ASSIGN(runtime_, std::move(builder).Build());
9193
}
9294

@@ -103,6 +105,23 @@ class RegexExtTest : public TestWithParam<RegexExtTestCase> {
103105
std::unique_ptr<const Runtime> runtime_;
104106
};
105107

108+
// Registering the regex extension on a runtime builder that never had
// optional types enabled must be rejected with InvalidArgument.
TEST_F(RegexExtTest, BuildFailsWithoutOptionalSupport) {
  RuntimeOptions runtime_options;
  runtime_options.enable_qualified_type_identifiers = true;
  runtime_options.enable_regex = true;

  ASSERT_OK_AND_ASSIGN(
      auto runtime_builder,
      CreateStandardRuntimeBuilder(internal::GetTestingDescriptorPool(),
                                   runtime_options));
  ASSERT_THAT(EnableReferenceResolver(runtime_builder,
                                      ReferenceResolverEnabled::kAlways),
              IsOk());

  // EnableOptionalTypes() is intentionally not called on this builder.
  const absl::Status registration =
      RegisterRegexExtensionFunctions(runtime_builder);
  ASSERT_THAT(
      registration,
      StatusIs(
          absl::StatusCode::kInvalidArgument,
          HasSubstr(
              "regex extensions requires the optional types to be enabled")));
}
106125
std::vector<RegexExtTestCase> regexTestCases() {
107126
return {
108127
// Tests for extract Function
@@ -121,6 +140,11 @@ std::vector<RegexExtTestCase> regexTestCases() {
121140
"regex.extract('hello world', 'goodbye (.*)')"},
122141
{EvaluationType::kOptionalNone, "regex.extract('HELLO', 'hello')"},
123142
{EvaluationType::kOptionalNone, R"(regex.extract('', r'\w+'))"},
143+
{EvaluationType::kBoolTrue,
144+
"regex.extract('4122345432', '22').orValue('777') == '22'"},
145+
{EvaluationType::kBoolTrue,
146+
"regex.extract('4122345432', '22').or(optional.of('777')) == "
147+
"optional.of('22')"},
124148

125149
// Tests for extractAll Function
126150
{EvaluationType::kBoolTrue,
@@ -328,5 +352,57 @@ TEST_P(RegexExtTest, RegexExtTests) {
328352

329353
INSTANTIATE_TEST_SUITE_P(RegexExtTest, RegexExtTest,
330354
ValuesIn(regexTestCases()));
355+
356+
// One type-checker scenario for the regex extension.
struct RegexCheckerTestCase {
  // CEL expression handed to the compiler.
  std::string expr_string;
  // Text expected inside the formatted error; left empty when the expression
  // is expected to validate.
  std::string error_substr;
};
360+
361+
// Fixture that builds a compiler containing the standard checker library and
// the regex extension compiler library; each parameter is one expression to
// type-check.
class RegexExtCheckerLibraryTest : public TestWithParam<RegexCheckerTestCase> {
 public:
  void SetUp() override {
    ASSERT_OK_AND_ASSIGN(std::unique_ptr<CompilerBuilder> builder,
                         NewCompilerBuilder(descriptor_pool_));

    // Standard library first, then the regex extension under test.
    ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk());
    ASSERT_THAT(builder->AddLibrary(RegexExtCompilerLibrary()), IsOk());

    ASSERT_OK_AND_ASSIGN(compiler_, std::move(*builder).Build());
  }

  // Descriptor pool the compiler builder is created from.
  const google::protobuf::DescriptorPool* descriptor_pool_ =
      internal::GetTestingDescriptorPool();
  // Compiler under test; populated by SetUp().
  std::unique_ptr<Compiler> compiler_;
};
378+
379+
// Compiles the parameter's expression. Cases with an empty `error_substr`
// must validate; all others must be invalid and mention `error_substr` in the
// formatted error.
TEST_P(RegexExtCheckerLibraryTest, RegexExtTypeCheckerTests) {
  const RegexCheckerTestCase& test_case = GetParam();

  ASSERT_OK_AND_ASSIGN(ValidationResult result,
                       compiler_->Compile(test_case.expr_string));

  const absl::string_view expected_error = test_case.error_substr;
  const bool expect_valid = expected_error.empty();
  EXPECT_EQ(result.IsValid(), expect_valid);

  if (expect_valid) {
    return;
  }
  EXPECT_THAT(result.FormatError(), HasSubstr(expected_error));
}
390+
391+
std::vector<RegexCheckerTestCase> createRegexCheckerParams() {
392+
return {
393+
{"regex.replace('abc', 'a', 's') == 'sbc'"},
394+
{"regex.replace('abc', 'a', 's') == 121",
395+
"found no matching overload for '_==_' applied to '(string, int)"},
396+
{"regex.replace('abc', 'j', '1', 2) == 9.0",
397+
"found no matching overload for '_==_' applied to '(string, double)"},
398+
{"regex.extractAll('banananana', '(ana)') == ['ana', 'ana']"},
399+
{"regex.extract('foo bar', 'f') == 121",
400+
"found no matching overload for '_==_' applied to "
401+
"'(optional_type(string), int)'"},
402+
};
403+
}
404+
405+
INSTANTIATE_TEST_SUITE_P(RegexExtCheckerLibraryTest, RegexExtCheckerLibraryTest,
406+
ValuesIn(createRegexCheckerParams()));
331407
} // namespace
332408
} // namespace cel::extensions

0 commit comments

Comments
 (0)