Thanks to visit codestin.com
Credit goes to github.com

Skip to content

- Add support for checked downcast on runtime to support extensions. #287

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions eval/compiler/BUILD
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
DEFAULT_VISIBILITY = [
"//eval:__subpackages__",
"//runtime:__subpackages__",
"//extensions:__subpackages__",
]

Expand Down
45 changes: 45 additions & 0 deletions runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ cc_library(
"//base:ast",
"//base:data",
"//base:handle",
"//internal:rtti",
"@com_google_absl//absl/functional:any_invocable",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
Expand Down Expand Up @@ -279,3 +280,47 @@ cc_library(
"@com_google_absl//absl/status",
],
)

cc_library(
name = "constant_folding",
srcs = ["constant_folding.cc"],
hdrs = ["constant_folding.h"],
deps = [
":runtime",
":runtime_builder",
"//base:memory",
"//eval/compiler:constant_folding",
"//internal:casts",
"//internal:rtti",
"//internal:status_macros",
"//runtime/internal:runtime_friend_access",
"//runtime/internal:runtime_impl",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
],
)

cc_test(
name = "constant_folding_test",
srcs = ["constant_folding_test.cc"],
deps = [
":activation",
":constant_folding",
":managed_value_factory",
":register_function_helper",
":runtime_builder",
":runtime_options",
":standard_runtime_builder_factory",
"//base:data",
"//base:function_adapter",
"//base:handle",
"//extensions/protobuf:runtime_adapter",
"//internal:testing",
"//parser",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto",
],
)
64 changes: 64 additions & 0 deletions runtime/constant_folding.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "runtime/constant_folding.h"

#include "absl/base/macros.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "base/memory.h"
#include "eval/compiler/constant_folding.h"
#include "internal/casts.h"
#include "internal/rtti.h"
#include "internal/status_macros.h"
#include "runtime/internal/runtime_friend_access.h"
#include "runtime/internal/runtime_impl.h"
#include "runtime/runtime.h"
#include "runtime/runtime_builder.h"

namespace cel::extensions {
namespace {

using ::cel::internal::down_cast;
using ::cel::internal::TypeId;
using ::cel::runtime_internal::RuntimeFriendAccess;
using ::cel::runtime_internal::RuntimeImpl;

absl::StatusOr<RuntimeImpl*> RuntimeImplFromBuilder(RuntimeBuilder& builder) {
Runtime& runtime = RuntimeFriendAccess::GetMutableRuntime(builder);

if (RuntimeFriendAccess::RuntimeTypeId(runtime) != TypeId<RuntimeImpl>()) {
return absl::UnimplementedError(
"constant folding only supported on the default cel::Runtime "
"implementation.");
}

RuntimeImpl& runtime_impl = down_cast<RuntimeImpl&>(runtime);

return &runtime_impl;
}

} // namespace

absl::Status EnableConstantFolding(RuntimeBuilder& builder,
MemoryManager& memory_manager) {
CEL_ASSIGN_OR_RETURN(RuntimeImpl * runtime_impl,
RuntimeImplFromBuilder(builder));
ABSL_ASSERT(runtime_impl != nullptr);
runtime_impl->expr_builder().AddProgramOptimizer(
runtime_internal::CreateConstantFoldingOptimizer(memory_manager));
return absl::OkStatus();
}

} // namespace cel::extensions
37 changes: 37 additions & 0 deletions runtime/constant_folding.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_CONSTANT_FOLDING_H_
#define THIRD_PARTY_CEL_CPP_RUNTIME_CONSTANT_FOLDING_H_

#include "absl/status/status.h"
#include "base/memory.h"
#include "runtime/runtime_builder.h"

namespace cel::extensions {

// Enable constant folding in the runtime being built.
//
// Constant folding eagerly evaluates sub-expressions with all constant inputs
// at plan time to simplify the resulting program. User extensions functions are
// executed if they are eagerly bound.
//
// The provided memory manager must outlive the runtime object built
// from builder.
absl::Status EnableConstantFolding(RuntimeBuilder& builder,
MemoryManager& memory_manager);

} // namespace cel::extensions

#endif // THIRD_PARTY_CEL_CPP_RUNTIME_CONSTANT_FOLDING_H_
143 changes: 143 additions & 0 deletions runtime/constant_folding_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "runtime/constant_folding.h"

#include <string>
#include <utility>
#include <vector>

#include "google/api/expr/v1alpha1/syntax.pb.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/match.h"
#include "base/function_adapter.h"
#include "base/handle.h"
#include "base/value.h"
#include "base/values/bool_value.h"
#include "base/values/int_value.h"
#include "base/values/string_value.h"
#include "extensions/protobuf/runtime_adapter.h"
#include "internal/testing.h"
#include "parser/parser.h"
#include "runtime/activation.h"
#include "runtime/managed_value_factory.h"
#include "runtime/register_function_helper.h"
#include "runtime/runtime_builder.h"
#include "runtime/runtime_options.h"
#include "runtime/standard_runtime_builder_factory.h"

namespace cel::extensions {
namespace {

using ::google::api::expr::v1alpha1::ParsedExpr;
using ::google::api::expr::parser::Parse;
using testing::HasSubstr;
using cel::internal::StatusIs;

using ValueMatcher = testing::Matcher<Handle<Value>>;

struct TestCase {
std::string name;
std::string expression;
ValueMatcher result_matcher;
absl::Status status;
};

MATCHER_P(IsIntValue, expected, "") {
const Handle<Value>& value = arg;
return value->Is<IntValue>() && value->As<IntValue>().value() == expected;
}

MATCHER_P(IsBoolValue, expected, "") {
const Handle<Value>& value = arg;
return value->Is<BoolValue>() && value->As<BoolValue>().value() == expected;
}

MATCHER_P(IsErrorValue, expected_substr, "") {
const Handle<Value>& value = arg;
return value->Is<ErrorValue>() &&
absl::StrContains(value->As<ErrorValue>().value().message(),
expected_substr);
}

class ConstantFoldingExtTest : public testing::TestWithParam<TestCase> {};

TEST_P(ConstantFoldingExtTest, Runner) {
RuntimeOptions options;
const TestCase& test_case = GetParam();
ASSERT_OK_AND_ASSIGN(cel::RuntimeBuilder builder,
CreateStandardRuntimeBuilder(options));

auto status = RegisterHelper<BinaryFunctionAdapter<
absl::StatusOr<Handle<Value>>, const StringValue&, const StringValue&>>::
RegisterGlobalOverload(
"prepend",
[](ValueFactory& f, const StringValue& value,
const StringValue& prefix) {
return StringValue::Concat(f, prefix, value);
},
builder.function_registry());
ASSERT_OK(status);

ASSERT_OK(EnableConstantFolding(builder, MemoryManager::Global()));

ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build());

ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(test_case.expression));

ASSERT_OK_AND_ASSIGN(auto program, ProtobufRuntimeAdapter::CreateProgram(
*runtime, parsed_expr));

ManagedValueFactory value_factory(program->GetTypeProvider(),
MemoryManager::Global());
Activation activation;

auto result = program->Evaluate(activation, value_factory.get());
if (test_case.status.ok()) {
ASSERT_OK_AND_ASSIGN(Handle<Value> value, std::move(result));

EXPECT_THAT(value, test_case.result_matcher);
return;
}

EXPECT_THAT(result.status(), StatusIs(test_case.status.code(),
HasSubstr(test_case.status.message())));
}

INSTANTIATE_TEST_SUITE_P(
Cases, ConstantFoldingExtTest,
testing::ValuesIn(std::vector<TestCase>{
{"sum", "1 + 2 + 3", IsIntValue(6)},
{"list_create", "[1, 2, 3, 4].filter(x, x < 4).size()", IsIntValue(3)},
{"string_concat", "('12' + '34' + '56' + '78' + '90').size()",
IsIntValue(10)},
{"comprehension", "[1, 2, 3, 4].exists(x, x in [4, 5, 6, 7])",
IsBoolValue(true)},
{"nested_comprehension",
"[1, 2, 3, 4].exists(x, [1, 2, 3, 4].all(y, y <= x))",
IsBoolValue(true)},
{"runtime_error", "[1, 2, 3, 4].exists(x, ['4'].all(y, y <= x))",
IsErrorValue("No matching overloads")},
// TODO(uncreated-issue/32): Depends on map creation
// {"map_create", "{'abc': 'def', 'abd': 'deg'}.size()", 2},
{"custom_function", "prepend('def', 'abc') == 'abcdef'",
IsBoolValue(true)}}),

[](const testing::TestParamInfo<TestCase>& info) {
return info.param.name;
});

} // namespace
} // namespace cel::extensions
11 changes: 11 additions & 0 deletions runtime/internal/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,16 @@ cc_test(
],
)

cc_library(
name = "runtime_friend_access",
hdrs = ["runtime_friend_access.h"],
deps = [
"//internal:rtti",
"//runtime",
"//runtime:runtime_builder",
],
)

cc_library(
name = "runtime_impl",
srcs = ["runtime_impl.cc"],
Expand All @@ -65,6 +75,7 @@ cc_library(
"//base:handle",
"//eval/compiler:flat_expr_builder",
"//eval/eval:evaluator_core",
"//internal:rtti",
"//internal:status_macros",
"//runtime",
"//runtime:activation_interface",
Expand Down
45 changes: 45 additions & 0 deletions runtime/internal/runtime_friend_access.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_FRIEND_ACCESS_H_
#define THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_FRIEND_ACCESS_H_

#include "internal/rtti.h"
#include "runtime/runtime.h"
#include "runtime/runtime_builder.h"

namespace cel::runtime_internal {

// Provide accessors for friend-visibility internal runtime details.
//
// CEL supported runtime extensions need implementation specific details to work
// correctly. We restrict access to prevent external usages since we don't
// guarantee stability on the implementation details.
class RuntimeFriendAccess {
public:
// Access underlying runtime instance.
static Runtime& GetMutableRuntime(RuntimeBuilder& builder) {
return builder.runtime();
}

// Return the internal type_id for the runtime instance for checked down
// casting.
static internal::TypeInfo RuntimeTypeId(Runtime& runtime) {
return runtime.TypeId();
}
};

} // namespace cel::runtime_internal

#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_RUNTIME_EXTENSIONS_FRIEND_ACCESS_H_
Loading