Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Add support for recursively planned bind expressions. #696

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions eval/compiler/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -261,10 +261,15 @@ cc_test(
"//eval/public/structs:cel_proto_wrapper",
"//eval/public/structs:protobuf_descriptor_type_provider",
"//eval/public/testing:matchers",
"//extensions:bindings_ext",
"//internal:testing",
"//parser",
"//parser:macro",
"//runtime:runtime_options",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@com_google_cel_spec//proto/test/v1/proto3:test_all_types_cc_proto",
"@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto",
"@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto",
Expand Down
6 changes: 4 additions & 2 deletions eval/compiler/cel_expression_builder_flat_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,10 @@ CelExpressionBuilderFlatImpl::CreateExpressionImpl(
}
}
if (flat_expr_builder_.options().max_recursion_depth != 0 &&
impl.path().size() == 1 &&
impl.path().front()->GetNativeTypeId() ==
!impl.subexpressions().empty() &&
// mainline expression is exactly one recursive step.
impl.subexpressions().front().size() == 1 &&
impl.subexpressions().front().front()->GetNativeTypeId() ==
cel::NativeTypeId::For<WrappedDirectStep>()) {
return CelExpressionRecursiveImpl::Create(std::move(impl));
}
Expand Down
32 changes: 28 additions & 4 deletions eval/compiler/cel_expression_builder_flat_impl_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,17 @@
#include "eval/compiler/cel_expression_builder_flat_impl.h"

#include <cstdint>
#include <iterator>
#include <memory>
#include <string>
#include <vector>

#include "google/api/expr/v1alpha1/checked.pb.h"
#include "google/api/expr/v1alpha1/syntax.pb.h"
#include "absl/algorithm/container.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "eval/eval/cel_expression_flat_impl.h"
#include "eval/public/activation.h"
#include "eval/public/builtin_func_registrar.h"
Expand All @@ -35,7 +39,9 @@
#include "eval/public/structs/cel_proto_wrapper.h"
#include "eval/public/structs/protobuf_descriptor_type_provider.h"
#include "eval/public/testing/matchers.h"
#include "extensions/bindings_ext.h"
#include "internal/testing.h"
#include "parser/macro.h"
#include "parser/parser.h"
#include "runtime/runtime_options.h"
#include "proto/test/v1/proto3/test_all_types.pb.h"
Expand All @@ -51,7 +57,9 @@ using ::google::api::expr::v1alpha1::CheckedExpr;
using ::google::api::expr::v1alpha1::Expr;
using ::google::api::expr::v1alpha1::ParsedExpr;
using ::google::api::expr::v1alpha1::SourceInfo;
using ::google::api::expr::parser::Macro;
using ::google::api::expr::parser::Parse;
using ::google::api::expr::parser::ParseWithMacros;
using ::google::api::expr::test::v1::proto3::NestedTestAllTypes;
using ::google::api::expr::test::v1::proto3::TestAllTypes;
using testing::_;
Expand Down Expand Up @@ -94,9 +102,19 @@ struct RecursiveTestCase {

class RecursivePlanTest : public ::testing::TestWithParam<RecursiveTestCase> {};

absl::StatusOr<ParsedExpr> ParseWithBind(absl::string_view cel) {
static const std::vector<Macro>* kMacros = []() {
auto* result = new std::vector<Macro>(Macro::AllMacros());
absl::c_copy(cel::extensions::bindings_macros(),
std::back_inserter(*result));
return result;
}();
return ParseWithMacros(cel, *kMacros, "<input>");
}

TEST_P(RecursivePlanTest, ParsedExprRecursiveOptimizedImpl) {
const RecursiveTestCase& test_case = GetParam();
ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(test_case.expr));
ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, ParseWithBind(test_case.expr));
cel::RuntimeOptions options;
options.container = "google.api.expr.test.v1.proto3";
google::protobuf::Arena arena;
Expand Down Expand Up @@ -135,7 +153,7 @@ TEST_P(RecursivePlanTest, ParsedExprRecursiveOptimizedImpl) {

TEST_P(RecursivePlanTest, ParsedExprRecursiveTraceSupport) {
const RecursiveTestCase& test_case = GetParam();
ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(test_case.expr));
ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, ParseWithBind(test_case.expr));
cel::RuntimeOptions options;
options.container = "google.api.expr.test.v1.proto3";
google::protobuf::Arena arena;
Expand Down Expand Up @@ -180,7 +198,7 @@ TEST_P(RecursivePlanTest, Disabled) {
google::protobuf::LinkMessageReflection<TestAllTypes>();

const RecursiveTestCase& test_case = GetParam();
ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(test_case.expr));
ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, ParseWithBind(test_case.expr));
cel::RuntimeOptions options;
options.container = "google.api.expr.test.v1.proto3";
google::protobuf::Arena arena;
Expand Down Expand Up @@ -251,7 +269,13 @@ INSTANTIATE_TEST_SUITE_P(
"NestedTestAllTypes{payload: TestAllTypes{single_int64: "
"-42}}.payload.single_int64",
test::IsCelInt64(-42)},
}),
{"bind", R"(cel.bind(x, "1", x + x + x + x))",
test::IsCelString("1111")},
{"nested_bind", R"(cel.bind(x, 20, cel.bind(y, 30, x + y)))",
test::IsCelInt64(50)},
{"bind_with_comprehensions",
R"(cel.bind(x, [1, 2], cel.bind(y, x.map(z, z * 2), y.exists(z, z == 4))))",
test::IsCelBool(true)}}),

[](const testing::TestParamInfo<RecursiveTestCase>& info) -> std::string {
return info.param.test_name;
Expand Down
54 changes: 48 additions & 6 deletions eval/compiler/flat_expr_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -576,9 +576,24 @@ class FlatExprVisitor : public cel::ast_internal::AstVisitor {
SlotLookupResult slot = LookupSlot(path);

if (slot.subexpression >= 0) {
AddStep(
CreateCheckLazyInitStep(slot.slot, slot.subexpression, expr->id()));
AddStep(CreateAssignSlotStep(slot.slot));
auto* subexpression =
program_builder_.GetExtractedSubexpression(slot.subexpression);
if (subexpression == nullptr) {
SetProgressStatusError(
absl::InternalError("bad subexpression reference"));
return;
}
if (subexpression->IsRecursive()) {
const auto& program = subexpression->recursive_program();
SetRecursiveStep(
CreateDirectLazyInitStep(slot.slot, program.step.get(), expr->id()),
program.depth + 1);
} else {
// Off by one since mainline expression will be index 0.
AddStep(CreateCheckLazyInitStep(slot.slot, slot.subexpression + 1,
expr->id()));
AddStep(CreateAssignSlotStep(slot.slot));
}
return;
} else if (slot.slot >= 0) {
if (options_.max_recursion_depth != 0) {
Expand Down Expand Up @@ -843,6 +858,33 @@ class FlatExprVisitor : public cel::ast_internal::AstVisitor {
}
}

void MaybeMakeBindRecursive(
const cel::ast_internal::Expr* expr,
const cel::ast_internal::Comprehension* comprehension, size_t accu_slot) {
if (options_.max_recursion_depth == 0) {
return;
}

auto* result_plan =
program_builder_.GetSubexpression(&comprehension->result());

if (result_plan == nullptr || !result_plan->IsRecursive()) {
return;
}

int result_depth = result_plan->recursive_program().depth;

if (options_.max_recursion_depth > 0 &&
result_depth >= options_.max_recursion_depth) {
return;
}

auto program = result_plan->ExtractRecursiveProgram();
SetRecursiveStep(
CreateDirectBindStep(accu_slot, std::move(program.step), expr->id()),
result_depth + 1);
}

void MaybeMakeComprehensionRecursive(
const cel::ast_internal::Expr* expr,
const cel::ast_internal::Comprehension* comprehension, size_t iter_slot,
Expand Down Expand Up @@ -1443,8 +1485,7 @@ class FlatExprVisitor : public cel::ast_internal::AstVisitor {
return absl::InternalError("Failed to extract subexpression");
}

// off by one since mainline expression is handled separately.
record.subexpression = index + 1;
record.subexpression = index;

record.visitor->MarkAccuInitExtracted();

Expand Down Expand Up @@ -1806,7 +1847,8 @@ void ComprehensionVisitor::PostVisitArgTrivial(

void ComprehensionVisitor::PostVisit(const cel::ast_internal::Expr* expr) {
if (is_trivial_) {
// TODO(uncreated-issue/67): need to add mechanism for lazy eval for binds.
visitor_->MaybeMakeBindRecursive(expr, &expr->comprehension_expr(),
accu_slot_);
return;
}
visitor_->MaybeMakeComprehensionRecursive(expr, &expr->comprehension_expr(),
Expand Down
11 changes: 11 additions & 0 deletions eval/compiler/flat_expr_builder_extensions.h
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,17 @@ class ProgramBuilder {
absl::Nullable<Subexpression*> GetSubexpression(
const cel::ast_internal::Expr* expr);

// Return the extracted subexpression mapped to the given index.
//
// Returns nullptr if the mapping doesn't exist
absl::Nullable<Subexpression*> GetExtractedSubexpression(size_t index) {
if (index >= extracted_subexpressions_.size()) {
return nullptr;
}

return extracted_subexpressions_[index].get();
}

// Return index to the extracted subexpression.
//
// Returns -1 if the subexpression is not found.
Expand Down
6 changes: 6 additions & 0 deletions eval/eval/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1069,9 +1069,15 @@ cc_library(
srcs = ["lazy_init_step.cc"],
hdrs = ["lazy_init_step.h"],
deps = [
":attribute_trail",
":direct_expression_step",
":evaluator_core",
":expression_step_base",
"//common:value",
"//internal:status_macros",
"@com_google_absl//absl/base:nullability",
"@com_google_absl//absl/status",
"@com_google_googleapis//google/api/expr/v1alpha1:value_cc_proto",
],
)

Expand Down
4 changes: 2 additions & 2 deletions eval/eval/cel_expression_flat_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,11 @@ absl::StatusOr<CelValue> CelExpressionFlatImpl::Evaluate(

absl::StatusOr<std::unique_ptr<CelExpressionRecursiveImpl>>
CelExpressionRecursiveImpl::Create(FlatExpression flat_expr) {
if (flat_expr.path().size() != 1 ||
if (flat_expr.path().empty() ||
flat_expr.path().front()->GetNativeTypeId() !=
cel::NativeTypeId::For<WrappedDirectStep>()) {
return absl::InvalidArgumentError(absl::StrCat(
"Expected a single recursive program step", flat_expr.path().size()));
"Expected a recursive program step", flat_expr.path().size()));
}

auto* instance = new CelExpressionRecursiveImpl(std::move(flat_expr));
Expand Down
4 changes: 4 additions & 0 deletions eval/eval/evaluator_core.h
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,10 @@ class FlatExpression {

const ExecutionPath& path() const { return path_; }

absl::Span<const ExecutionPathView> subexpressions() const {
return subexpressions_;
}

const cel::RuntimeOptions& options() const { return options_; }

size_t comprehension_slots_size() const { return comprehension_slots_size_; }
Expand Down
73 changes: 73 additions & 0 deletions eval/eval/lazy_init_step.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,24 @@
#include <cstddef>
#include <cstdint>
#include <memory>
#include <utility>

#include "google/api/expr/v1alpha1/value.pb.h"
#include "absl/base/nullability.h"
#include "absl/status/status.h"
#include "common/value.h"
#include "eval/eval/attribute_trail.h"
#include "eval/eval/direct_expression_step.h"
#include "eval/eval/evaluator_core.h"
#include "eval/eval/expression_step_base.h"
#include "internal/status_macros.h"

namespace google::api::expr::runtime {

namespace {

using ::cel::Value;

class CheckLazyInitStep : public ExpressionStepBase {
public:
CheckLazyInitStep(size_t slot_index, size_t subexpression_index,
Expand All @@ -52,6 +61,57 @@ class CheckLazyInitStep : public ExpressionStepBase {
size_t subexpression_index_;
};

class DirectCheckLazyInitStep : public DirectExpressionStep {
public:
DirectCheckLazyInitStep(size_t slot_index,
const DirectExpressionStep* subexpression,
int64_t expr_id)
: DirectExpressionStep(expr_id),
slot_index_(slot_index),
subexpression_(subexpression) {}

absl::Status Evaluate(ExecutionFrameBase& frame, Value& result,
AttributeTrail& attribute) const override {
auto* slot = frame.comprehension_slots().Get(slot_index_);
if (slot != nullptr) {
result = slot->value;
attribute = slot->attribute;
return absl::OkStatus();
}

CEL_RETURN_IF_ERROR(subexpression_->Evaluate(frame, result, attribute));
frame.comprehension_slots().Set(slot_index_, result, attribute);

return absl::OkStatus();
}

private:
size_t slot_index_;
absl::Nonnull<const DirectExpressionStep*> subexpression_;
};

class BindStep : public DirectExpressionStep {
public:
BindStep(size_t slot_index,
std::unique_ptr<DirectExpressionStep> subexpression, int64_t expr_id)
: DirectExpressionStep(expr_id),
slot_index_(slot_index),
subexpression_(std::move(subexpression)) {}

absl::Status Evaluate(ExecutionFrameBase& frame, Value& result,
AttributeTrail& attribute) const override {
CEL_RETURN_IF_ERROR(subexpression_->Evaluate(frame, result, attribute));

frame.comprehension_slots().ClearSlot(slot_index_);

return absl::OkStatus();
}

private:
size_t slot_index_;
std::unique_ptr<DirectExpressionStep> subexpression_;
};

class AssignSlotStep : public ExpressionStepBase {
public:
explicit AssignSlotStep(size_t slot_index, bool should_pop)
Expand Down Expand Up @@ -95,6 +155,19 @@ class ClearSlotStep : public ExpressionStepBase {

} // namespace

std::unique_ptr<DirectExpressionStep> CreateDirectBindStep(
size_t slot_index, std::unique_ptr<DirectExpressionStep> expression,
int64_t expr_id) {
return std::make_unique<BindStep>(slot_index, std::move(expression), expr_id);
}

std::unique_ptr<DirectExpressionStep> CreateDirectLazyInitStep(
size_t slot_index, absl::Nonnull<const DirectExpressionStep*> subexpression,
int64_t expr_id) {
return std::make_unique<DirectCheckLazyInitStep>(slot_index, subexpression,
expr_id);
}

std::unique_ptr<ExpressionStep> CreateCheckLazyInitStep(
size_t slot_index, size_t subexpression_index, int64_t expr_id) {
return std::make_unique<CheckLazyInitStep>(slot_index, subexpression_index,
Expand Down
13 changes: 13 additions & 0 deletions eval/eval/lazy_init_step.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,23 @@
#include <cstdint>
#include <memory>

#include "absl/base/nullability.h"
#include "eval/eval/direct_expression_step.h"
#include "eval/eval/evaluator_core.h"

namespace google::api::expr::runtime {

// Creates a step representing a Bind expression.
std::unique_ptr<DirectExpressionStep> CreateDirectBindStep(
size_t slot_index, std::unique_ptr<DirectExpressionStep> expression,
int64_t expr_id);

// Creates a direct step representing accessing a lazily evaluated alias from
// a bind or block.
std::unique_ptr<DirectExpressionStep> CreateDirectLazyInitStep(
size_t slot_index, absl::Nonnull<const DirectExpressionStep*> subexpression,
int64_t expr_id);

// Creates a guard step that checks that an alias is initialized.
// If it is, push to stack and jump to the step that depends on the value.
// Otherwise, run the initialization routine (which pushes the value to top of
Expand Down
Loading