diff --git a/checker/internal/BUILD b/checker/internal/BUILD index 25550616a..26c7b543f 100644 --- a/checker/internal/BUILD +++ b/checker/internal/BUILD @@ -145,6 +145,7 @@ cc_library( "//common:container", "//common:decl", "//common:expr", + "//common:standard_definitions", "//common:type", "//common:type_kind", "//internal:lexis", @@ -238,6 +239,7 @@ cc_library( deps = [ ":format_type_name", "//common:decl", + "//common:standard_definitions", "//common:type", "//common:type_kind", "@com_google_absl//absl/container:flat_hash_map", diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index 1ce871255..6b6b051b1 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -52,6 +52,7 @@ #include "common/constant.h" #include "common/decl.h" #include "common/expr.h" +#include "common/standard_definitions.h" #include "common/type.h" #include "common/type_kind.h" #include "internal/status_macros.h" @@ -894,8 +895,12 @@ const FunctionDecl* ResolveVisitor::ResolveFunctionCallShape( if (decl == nullptr) { return true; } + bool is_logical_op = (candidate == cel::StandardFunctions::kAnd || + candidate == cel::StandardFunctions::kOr) && + arg_count >= 2; for (const auto& ovl : decl->overloads()) { - if (ovl.member() == is_receiver && ovl.args().size() == arg_count) { + if (ovl.member() == is_receiver && + (ovl.args().size() == arg_count || is_logical_op)) { return false; } } diff --git a/checker/internal/type_checker_impl_test.cc b/checker/internal/type_checker_impl_test.cc index 893f0689d..61ef7d55b 100644 --- a/checker/internal/type_checker_impl_test.cc +++ b/checker/internal/type_checker_impl_test.cc @@ -55,6 +55,7 @@ #include "cel/expr/conformance/proto3/test_all_types.pb.h" #include "google/protobuf/arena.h" #include "google/protobuf/message.h" +#include "google/protobuf/text_format.h" namespace cel { namespace checker_internal { @@ -1471,6 +1472,93 @@ TEST(TypeCheckerImplTest, TypeInferredFromStructCreation) { std::make_unique(DynTypeSpec()))))))); } +struct VariadicLogicalCheckerTestCase { + std::string expr; +}; + +class VariadicLogicalCheckerTest + : public testing::TestWithParam {}; + +TEST_P(VariadicLogicalCheckerTest, Check) { + const auto& test_case = GetParam(); + + auto builder = cel::NewParserBuilder(); + builder->GetOptions().enable_variadic_logical_operators = true; + ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); + ASSERT_OK_AND_ASSIGN(auto source, cel::NewSource(test_case.expr)); + ASSERT_OK_AND_ASSIGN(auto parsed_ast, parser->Parse(*source)); + + google::protobuf::Arena arena; + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + TypeCheckerImpl impl(std::move(env)); + auto checker_builder = impl.ToBuilder(); + ASSERT_THAT(checker_builder->AddVariable(MakeVariableDecl("a", BoolType())), + IsOk()); + ASSERT_THAT(checker_builder->AddVariable(MakeVariableDecl("b", BoolType())), + IsOk()); + ASSERT_THAT(checker_builder->AddVariable(MakeVariableDecl("c", BoolType())), + IsOk()); + ASSERT_THAT(checker_builder->AddVariable(MakeVariableDecl("d", BoolType())), + IsOk()); + ASSERT_THAT(checker_builder->AddVariable(MakeVariableDecl("e", BoolType())), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto checker, checker_builder->Build()); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + checker->Check(std::move(parsed_ast))); + + ASSERT_TRUE(result.IsValid()) + << absl::StrJoin(result.GetIssues(), "\n", + [](std::string* out, const TypeCheckIssue& issue) { + absl::StrAppend(out, issue.message()); + }); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + EXPECT_THAT(checked_ast->type_map(), + Contains(Pair(checked_ast->root_expr().id(), + Eq(AstType(PrimitiveType::kBool))))); +} + +INSTANTIATE_TEST_SUITE_P( + VariadicLogicalChecker, VariadicLogicalCheckerTest, + testing::Values(VariadicLogicalCheckerTestCase{"true && false && true"}, + VariadicLogicalCheckerTestCase{"a && b && c && d"}, + VariadicLogicalCheckerTestCase{"a || b || c || d"}, + VariadicLogicalCheckerTestCase{"a && b && (c || d || e)"}, + VariadicLogicalCheckerTestCase{"a && b && c"}, + VariadicLogicalCheckerTestCase{"a || b || c"}, + VariadicLogicalCheckerTestCase{"[a, b, c].exists(x, x)"}, + VariadicLogicalCheckerTestCase{"[a, b, c].all(x, x)"})); + +TEST(TypeCheckerImplTest, VariadicLogicalOperatorsError) { + cel::expr::ParsedExpr parsed_expr; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + expr { + call_expr { + function: "_&&_" + args { const_expr { bool_value: true } } + } + } + )pb", + &parsed_expr)); + ASSERT_OK_AND_ASSIGN(auto parsed_ast, + cel::CreateAstFromParsedExpr(parsed_expr)); + + google::protobuf::Arena arena; + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + impl.Check(std::move(parsed_ast))); + + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT( + result.GetIssues(), + Contains(IsIssueWithSubstring(Severity::kError, "undeclared reference"))); +} + TEST(TypeCheckerImplTest, ExpectedTypeMatches) { google::protobuf::Arena arena; TypeCheckEnv env(GetSharedTestingDescriptorPool()); diff --git a/checker/internal/type_inference_context.cc b/checker/internal/type_inference_context.cc index 5b909d982..1a87d9e15 100644 --- a/checker/internal/type_inference_context.cc +++ b/checker/internal/type_inference_context.cc @@ -30,6 +30,7 @@ #include "absl/types/span.h" #include "checker/internal/format_type_name.h" #include "common/decl.h" +#include "common/standard_definitions.h" #include "common/type.h" #include "common/type_kind.h" @@ -537,21 +538,28 @@ TypeInferenceContext::ResolveOverload(const FunctionDecl& decl, bool is_receiver) { std::optional result_type; + bool is_logical_op = (decl.name() == cel::StandardFunctions::kAnd || + decl.name() == cel::StandardFunctions::kOr) && + argument_types.size() >= 2; + std::vector matching_overloads; for (const auto& ovl : decl.overloads()) { if (ovl.member() != is_receiver || - argument_types.size() != ovl.args().size()) { + (!is_logical_op && argument_types.size() != ovl.args().size())) { continue; } auto call_type_instance = InstantiateFunctionOverload(*this, ovl); - ABSL_DCHECK_EQ(argument_types.size(), - call_type_instance.param_types.size()); + if (!is_logical_op) { + ABSL_DCHECK_EQ(argument_types.size(), + call_type_instance.param_types.size()); + } bool is_match = true; AssignabilityContext assignability_context = CreateAssignabilityContext(); for (int i = 0; i < argument_types.size(); ++i) { + int param_index = is_logical_op ? 0 : i; if (!assignability_context.IsAssignable( - argument_types[i], call_type_instance.param_types[i])) { + argument_types[i], call_type_instance.param_types[param_index])) { is_match = false; break; } diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index 1e3f4ecd3..d6ccdf040 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -232,7 +232,7 @@ class BinaryCondVisitor : public CondVisitor { private: FlatExprVisitor* visitor_; const BinaryCond cond_; - Jump jump_step_; + std::vector jump_steps_; bool short_circuiting_; }; @@ -622,7 +622,7 @@ class FlatExprVisitor : public cel::AstVisitor { program_optimizers_) { absl::Status status = optimizer->OnPreVisit(extension_context_, expr); if (!status.ok()) { - SetProgressStatusError(status); + SetProgressStatusIfError(status); } } } @@ -639,7 +639,7 @@ class FlatExprVisitor : public cel::AstVisitor { program_optimizers_) { absl::Status status = optimizer->OnPostVisit(extension_context_, expr); if (!status.ok()) { - SetProgressStatusError(status); + SetProgressStatusIfError(status); return; } } @@ -657,7 +657,7 @@ class FlatExprVisitor : public cel::AstVisitor { if (!comprehension_stack_.empty() && comprehension_stack_.back().is_optimizable_bind && (&comprehension_stack_.back().comprehension->accu_init() == &expr)) { - SetProgressStatusError( + SetProgressStatusIfError( MaybeExtractSubexpression(&expr, comprehension_stack_.back())); } @@ -666,7 +666,7 @@ class FlatExprVisitor : public cel::AstVisitor { if (block.current_binding == &expr) { int index = program_builder_.ExtractSubexpression(&expr); if (index == -1) { - SetProgressStatusError( + SetProgressStatusIfError( absl::InvalidArgumentError("failed to extract subexpression")); return; } @@ -686,7 +686,7 @@ class FlatExprVisitor : public cel::AstVisitor { ConvertConstant(const_expr, cel::NewDeleteAllocator()); if (!converted_value.ok()) { - SetProgressStatusError(converted_value.status()); + SetProgressStatusIfError(converted_value.status()); return; } @@ -722,13 +722,13 @@ class FlatExprVisitor : public cel::AstVisitor { if (absl::ConsumePrefix(&index_suffix, "@index")) { size_t index; if (!absl::SimpleAtoi(index_suffix, &index)) { - SetProgressStatusError( + SetProgressStatusIfError( issue_collector_.AddIssue(RuntimeIssue::CreateError( absl::InvalidArgumentError("bad @index")))); return {-1, -1}; } if (index >= block.size) { - SetProgressStatusError( + SetProgressStatusIfError( issue_collector_.AddIssue(RuntimeIssue::CreateError( absl::InvalidArgumentError(absl::StrCat( "invalid @index greater than number of bindings: ", @@ -736,7 +736,7 @@ class FlatExprVisitor : public cel::AstVisitor { return {-1, -1}; } if (index >= block.current_index) { - SetProgressStatusError( + SetProgressStatusIfError( issue_collector_.AddIssue(RuntimeIssue::CreateError( absl::InvalidArgumentError(absl::StrCat( "@index references current or future binding: ", index, @@ -754,7 +754,7 @@ class FlatExprVisitor : public cel::AstVisitor { if (record.iter_var_in_scope && record.comprehension->iter_var() == path) { if (record.is_optimizable_bind) { - SetProgressStatusError(issue_collector_.AddIssue( + SetProgressStatusIfError(issue_collector_.AddIssue( RuntimeIssue::CreateWarning(absl::InvalidArgumentError( "Unexpected iter_var access in trivial comprehension")))); return {-1, -1}; @@ -781,7 +781,7 @@ class FlatExprVisitor : public cel::AstVisitor { // If we see a CSE generated comprehension variable that was not // resolvable through the normal comprehension scope resolution, reject it // now rather than surfacing errors at activation time. - SetProgressStatusError( + SetProgressStatusIfError( issue_collector_.AddIssue(RuntimeIssue::CreateError( absl::InvalidArgumentError("out of scope reference to CSE " "generated comprehension variable")))); @@ -811,7 +811,7 @@ class FlatExprVisitor : public cel::AstVisitor { auto* subexpression = program_builder_.GetExtractedSubexpression(slot.subexpression); if (subexpression == nullptr) { - SetProgressStatusError( + SetProgressStatusIfError( absl::InternalError("bad subexpression reference")); return; } @@ -965,7 +965,7 @@ class FlatExprVisitor : public cel::AstVisitor { if (auto depth = RecursionEligible(); depth.has_value()) { auto deps = ExtractRecursiveDependencies(); if (deps.size() != 1) { - SetProgressStatusError(absl::InternalError( + SetProgressStatusIfError(absl::InternalError( "unexpected number of dependencies for select operation.")); return; } @@ -1022,7 +1022,7 @@ class FlatExprVisitor : public cel::AstVisitor { // cel.@block if (block_.has_value()) { // There can only be one for now. - SetProgressStatusError( + SetProgressStatusIfError( absl::InvalidArgumentError("multiple cel.@block are not allowed")); return; } @@ -1030,17 +1030,17 @@ class FlatExprVisitor : public cel::AstVisitor { BlockInfo& block = *block_; block.in = true; if (call_expr.args().empty()) { - SetProgressStatusError(absl::InvalidArgumentError( + SetProgressStatusIfError(absl::InvalidArgumentError( "malformed cel.@block: missing list of bound expressions")); return; } if (call_expr.args().size() != 2) { - SetProgressStatusError(absl::InvalidArgumentError( + SetProgressStatusIfError(absl::InvalidArgumentError( "malformed cel.@block: missing bound expression")); return; } if (!call_expr.args()[0].has_list_expr()) { - SetProgressStatusError( + SetProgressStatusIfError( absl::InvalidArgumentError("malformed cel.@block: first argument " "is not a list of bound expressions")); return; @@ -1051,7 +1051,7 @@ class FlatExprVisitor : public cel::AstVisitor { block.bindings_set.reserve(block.size); for (const auto& list_expr_element : list_expr.elements()) { if (list_expr_element.optional()) { - SetProgressStatusError( + SetProgressStatusIfError( absl::InvalidArgumentError("malformed cel.@block: list of bound " "expressions contains an optional")); return; @@ -1093,7 +1093,7 @@ class FlatExprVisitor : public cel::AstVisitor { void MakeTernaryRecursive(const cel::Expr* expr) { if (expr->call_expr().args().size() != 3) { - SetProgressStatusError(absl::InvalidArgumentError( + SetProgressStatusIfError(absl::InvalidArgumentError( "unexpected number of args for builtin ternary")); return; } @@ -1109,7 +1109,7 @@ class FlatExprVisitor : public cel::AstVisitor { if (condition_plan == nullptr || !condition_plan->IsRecursive() || left_plan == nullptr || !left_plan->IsRecursive() || right_plan == nullptr || !right_plan->IsRecursive()) { - SetProgressStatusError(FailedRecursivePlanning()); + SetProgressStatusIfError(FailedRecursivePlanning()); return; } @@ -1126,45 +1126,52 @@ class FlatExprVisitor : public cel::AstVisitor { } void MakeShortcircuitRecursive(const cel::Expr* expr, bool is_or) { - if (expr->call_expr().args().size() != 2) { - SetProgressStatusError(absl::InvalidArgumentError( + int args_size = expr->call_expr().args().size(); + if (args_size < 2) { + SetProgressStatusIfError(absl::InvalidArgumentError( "unexpected number of args for builtin boolean operator &&/||")); return; } - const cel::Expr* left_expr = &expr->call_expr().args()[0]; - const cel::Expr* right_expr = &expr->call_expr().args()[1]; - auto* left_plan = program_builder_.GetSubexpression(left_expr); - auto* right_plan = program_builder_.GetSubexpression(right_expr); - - if (left_plan == nullptr || !left_plan->IsRecursive() || - right_plan == nullptr || !right_plan->IsRecursive()) { - SetProgressStatusError(FailedRecursivePlanning()); + auto* current_plan = + program_builder_.GetSubexpression(&expr->call_expr().args()[0]); + if (current_plan == nullptr || !current_plan->IsRecursive()) { + SetProgressStatusIfError(FailedRecursivePlanning()); return; } + int current_depth = current_plan->recursive_program().depth; + std::unique_ptr current_step = + current_plan->ExtractRecursiveProgram().step; - int max_depth = std::max({0, left_plan->recursive_program().depth, - right_plan->recursive_program().depth}); - - if (is_or) { - SetRecursiveStep( - CreateDirectOrStep(left_plan->ExtractRecursiveProgram().step, - right_plan->ExtractRecursiveProgram().step, - expr->id(), options_.short_circuiting), - max_depth + 1); - } else { - SetRecursiveStep( - CreateDirectAndStep(left_plan->ExtractRecursiveProgram().step, - right_plan->ExtractRecursiveProgram().step, - expr->id(), options_.short_circuiting), - max_depth + 1); + for (int i = 1; i < args_size; ++i) { + auto* next_plan = + program_builder_.GetSubexpression(&expr->call_expr().args()[i]); + if (next_plan == nullptr || !next_plan->IsRecursive()) { + SetProgressStatusIfError(FailedRecursivePlanning()); + return; + } + current_depth = + std::max(current_depth, next_plan->recursive_program().depth); + std::unique_ptr next_step = + next_plan->ExtractRecursiveProgram().step; + if (is_or) { + current_step = + CreateDirectOrStep(std::move(current_step), std::move(next_step), + expr->id(), options_.short_circuiting); + } else { + current_step = + CreateDirectAndStep(std::move(current_step), std::move(next_step), + expr->id(), options_.short_circuiting); + } + current_depth++; } + SetRecursiveStep(std::move(current_step), current_depth); } void MakeOptionalShortcircuit(const cel::Expr* expr, bool is_or_value) { if (!expr->call_expr().has_target() || expr->call_expr().args().size() != 1) { - SetProgressStatusError(absl::InvalidArgumentError( + SetProgressStatusIfError(absl::InvalidArgumentError( "unexpected number of args for optional.or{Value}")); return; } @@ -1176,7 +1183,7 @@ class FlatExprVisitor : public cel::AstVisitor { if (left_plan == nullptr || !left_plan->IsRecursive() || right_plan == nullptr || !right_plan->IsRecursive()) { - SetProgressStatusError(FailedRecursivePlanning()); + SetProgressStatusIfError(FailedRecursivePlanning()); return; } int max_depth = std::max({0, left_plan->recursive_program().depth, @@ -1200,7 +1207,7 @@ class FlatExprVisitor : public cel::AstVisitor { program_builder_.GetSubexpression(&comprehension->result()); if (result_plan == nullptr || !result_plan->IsRecursive()) { - SetProgressStatusError(FailedRecursivePlanning()); + SetProgressStatusIfError(FailedRecursivePlanning()); return; } @@ -1234,7 +1241,7 @@ class FlatExprVisitor : public cel::AstVisitor { loop_plan == nullptr || !loop_plan->IsRecursive() || condition_plan == nullptr || !condition_plan->IsRecursive() || result_plan == nullptr || !result_plan->IsRecursive()) { - SetProgressStatusError(FailedRecursivePlanning()); + SetProgressStatusIfError(FailedRecursivePlanning()); return; } @@ -1462,7 +1469,7 @@ class FlatExprVisitor : public cel::AstVisitor { return; } - SetProgressStatusError(comprehension_stack_.back().visitor->PostVisitArg( + SetProgressStatusIfError(comprehension_stack_.back().visitor->PostVisitArg( comprehension_arg, comprehension_stack_.back().expr)); } @@ -1524,7 +1531,7 @@ class FlatExprVisitor : public cel::AstVisitor { if (std::optional depth = RecursionEligible(); depth.has_value()) { auto deps = ExtractRecursiveDependencies(); if (deps.size() != list_expr.elements().size()) { - SetProgressStatusError(absl::InternalError( + SetProgressStatusIfError(absl::InternalError( "Unexpected number of plan elements for CreateList expr")); return; } @@ -1547,7 +1554,7 @@ class FlatExprVisitor : public cel::AstVisitor { auto status_or_resolved_fields = ResolveCreateStructFields(struct_expr, expr.id()); if (!status_or_resolved_fields.ok()) { - SetProgressStatusError(status_or_resolved_fields.status()); + SetProgressStatusIfError(status_or_resolved_fields.status()); return; } std::string resolved_name = @@ -1558,7 +1565,7 @@ class FlatExprVisitor : public cel::AstVisitor { if (auto depth = RecursionEligible(); depth.has_value()) { auto deps = ExtractRecursiveDependencies(); if (deps.size() != struct_expr.fields().size()) { - SetProgressStatusError(absl::InternalError( + SetProgressStatusIfError(absl::InternalError( "Unexpected number of plan elements for CreateStruct expr")); return; } @@ -1599,7 +1606,7 @@ class FlatExprVisitor : public cel::AstVisitor { if (auto depth = RecursionEligible(); depth.has_value()) { auto deps = ExtractRecursiveDependencies(); if (deps.size() != 2 * map_expr.entries().size()) { - SetProgressStatusError(absl::InternalError( + SetProgressStatusIfError(absl::InternalError( "Unexpected number of plan elements for CreateStruct expr")); return; } @@ -1661,7 +1668,7 @@ class FlatExprVisitor : public cel::AstVisitor { "No overloads provided for FunctionStep creation"), RuntimeIssue::ErrorCode::kNoMatchingOverload)); if (!status.ok()) { - SetProgressStatusError(status); + SetProgressStatusIfError(status); return; } } @@ -1692,7 +1699,7 @@ class FlatExprVisitor : public cel::AstVisitor { if (step.ok()) { return AddStep(*std::move(step)); } else { - SetProgressStatusError(step.status()); + SetProgressStatusIfError(step.status()); } return nullptr; } @@ -1711,19 +1718,19 @@ class FlatExprVisitor : public cel::AstVisitor { return; } if (program_builder_.current() == nullptr) { - SetProgressStatusError(absl::InternalError( + SetProgressStatusIfError(absl::InternalError( "CEL AST traversal out of order in flat_expr_builder.")); return; } program_builder_.current()->set_recursive_program(std::move(step), depth); if (depth > max_recursion_depth_) { - SetProgressStatusError(absl::InvalidArgumentError( + SetProgressStatusIfError(absl::InvalidArgumentError( absl::StrCat("Maximum recursion depth of ", options_.max_recursion_depth, " exceeded"))); } } - void SetProgressStatusError(const absl::Status& status) { + void SetProgressStatusIfError(const absl::Status& status) { if (progress_status_.ok() && !status.ok()) { progress_status_ = status; } @@ -1765,7 +1772,7 @@ class FlatExprVisitor : public cel::AstVisitor { if (valid_expression) { return true; } - SetProgressStatusError(absl::InvalidArgumentError( + SetProgressStatusIfError(absl::InvalidArgumentError( absl::StrCat(error_message, message_parts...))); return false; } @@ -1947,7 +1954,7 @@ FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleIndex( if (auto depth = RecursionEligible(); depth.has_value()) { auto args = ExtractRecursiveDependencies(); if (args.size() != 2) { - SetProgressStatusError(absl::InvalidArgumentError( + SetProgressStatusIfError(absl::InvalidArgumentError( "unexpected number of args for builtin index operator")); return CallHandlerResult::kIntercepted; } @@ -1974,7 +1981,7 @@ FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleNot( if (auto depth = RecursionEligible(); depth.has_value()) { auto args = ExtractRecursiveDependencies(); if (args.size() != 1) { - SetProgressStatusError(absl::InvalidArgumentError( + SetProgressStatusIfError(absl::InvalidArgumentError( "unexpected number of args for builtin not operator")); return CallHandlerResult::kIntercepted; } @@ -1997,7 +2004,7 @@ FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleNotStrictlyFalse( if (auto depth = RecursionEligible(); depth.has_value()) { auto args = ExtractRecursiveDependencies(); if (args.size() != 1) { - SetProgressStatusError( + SetProgressStatusIfError( absl::InvalidArgumentError("unexpected number of args for builtin " "@not_strictly_false operator")); return CallHandlerResult::kIntercepted; @@ -2016,7 +2023,7 @@ FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleBlock( ABSL_DCHECK(call_expr.function() == kBlock); if (!block_.has_value() || block_->expr != &expr || call_expr.args().size() != 2 || call_expr.has_target()) { - SetProgressStatusError( + SetProgressStatusIfError( absl::InvalidArgumentError("unexpected call to internal cel.@block")); return CallHandlerResult::kIntercepted; } @@ -2101,7 +2108,7 @@ FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleHeterogeneousEquality( if (auto depth = RecursionEligible(); depth.has_value()) { auto args = ExtractRecursiveDependencies(); if (args.size() != 2) { - SetProgressStatusError(absl::InvalidArgumentError( + SetProgressStatusIfError(absl::InvalidArgumentError( "unexpected number of args for builtin equality operator")); return CallHandlerResult::kIntercepted; } @@ -2126,7 +2133,7 @@ FlatExprVisitor::HandleHeterogeneousEqualityIn(const cel::Expr& expr, if (auto depth = RecursionEligible(); depth.has_value()) { auto args = ExtractRecursiveDependencies(); if (args.size() != 2) { - SetProgressStatusError(absl::InvalidArgumentError( + SetProgressStatusIfError(absl::InvalidArgumentError( "unexpected number of args for builtin 'in' operator")); return CallHandlerResult::kIntercepted; } @@ -2164,13 +2171,14 @@ void BinaryCondVisitor::PostVisitArg(int arg_num, const cel::Expr* expr) { if (visitor_->PlanRecursiveProgram()) { return; } - if (short_circuiting_ && arg_num == 0 && + const int last_arg_index = expr->call_expr().args().size() - 1; + if (short_circuiting_ && arg_num < last_arg_index && (cond_ == BinaryCond::kAnd || cond_ == BinaryCond::kOr)) { // If first branch evaluation result is enough to determine output, // jump over the second branch and provide result of the first argument as // final output. - // Retain a pointer to the jump step so we can update the target after - // planning the second argument. + // Retain pointers to the jump steps so we can update the target after + // planning the next arguments. std::unique_ptr jump_step; switch (cond_) { case BinaryCond::kAnd: @@ -2185,7 +2193,7 @@ void BinaryCondVisitor::PostVisitArg(int arg_num, const cel::Expr* expr) { ProgramStepIndex index = visitor_->GetCurrentIndex(); if (JumpStepBase* jump_step_ptr = visitor_->AddStep(std::move(jump_step)); jump_step_ptr) { - jump_step_ = Jump(index, jump_step_ptr); + jump_steps_.push_back(Jump(index, jump_step_ptr)); } } } @@ -2215,7 +2223,7 @@ void BinaryCondVisitor::PostVisitTarget(const cel::Expr* expr) { ProgramStepIndex index = visitor_->GetCurrentIndex(); if (JumpStepBase* jump_step_ptr = visitor_->AddStep(std::move(jump_step)); jump_step_ptr) { - jump_step_ = Jump(index, jump_step_ptr); + jump_steps_.push_back(Jump(index, jump_step_ptr)); } } } @@ -2243,28 +2251,36 @@ void BinaryCondVisitor::PostVisit(const cel::Expr* expr) { return; } - switch (cond_) { - case BinaryCond::kAnd: - visitor_->AddStep(CreateAndStep(expr->id())); - break; - case BinaryCond::kOr: - visitor_->AddStep(CreateOrStep(expr->id())); - break; - case BinaryCond::kOptionalOr: - visitor_->AddStep( - CreateOptionalOrStep(/*is_or_value=*/false, expr->id())); - break; - case BinaryCond::kOptionalOrValue: - visitor_->AddStep(CreateOptionalOrStep(/*is_or_value=*/true, expr->id())); - break; - default: - ABSL_UNREACHABLE(); + int args_count = (cond_ == BinaryCond::kAnd || cond_ == BinaryCond::kOr) + ? expr->call_expr().args().size() + : 2; + for (int i = 0; i < args_count - 1; ++i) { + switch (cond_) { + case BinaryCond::kAnd: + visitor_->AddStep(CreateAndStep(expr->id())); + break; + case BinaryCond::kOr: + visitor_->AddStep(CreateOrStep(expr->id())); + break; + case BinaryCond::kOptionalOr: + visitor_->AddStep( + CreateOptionalOrStep(/*is_or_value=*/false, expr->id())); + break; + case BinaryCond::kOptionalOrValue: + visitor_->AddStep( + CreateOptionalOrStep(/*is_or_value=*/true, expr->id())); + break; + default: + ABSL_UNREACHABLE(); + } } if (short_circuiting_) { // If short-circuiting is enabled, point the conditional jump past the // boolean operator step. - visitor_->SetProgressStatusError( - jump_step_.set_target(visitor_->GetCurrentIndex())); + for (auto& jump : jump_steps_) { + visitor_->SetProgressStatusIfError( + jump.set_target(visitor_->GetCurrentIndex())); + } } } @@ -2321,7 +2337,7 @@ void TernaryCondVisitor::PostVisitArg(int arg_num, const cel::Expr* expr) { if (visitor_->ValidateOrError( jump_to_second_.exists(), "Error configuring ternary operator: jump_to_second_ is null")) { - visitor_->SetProgressStatusError( + visitor_->SetProgressStatusIfError( jump_to_second_.set_target(visitor_->GetCurrentIndex())); } } @@ -2339,13 +2355,13 @@ void TernaryCondVisitor::PostVisit(const cel::Expr* expr) { if (visitor_->ValidateOrError( error_jump_.exists(), "Error configuring ternary operator: error_jump_ is null")) { - visitor_->SetProgressStatusError( + visitor_->SetProgressStatusIfError( error_jump_.set_target(visitor_->GetCurrentIndex())); } if (visitor_->ValidateOrError( jump_after_first_.exists(), "Error configuring ternary operator: jump_after_first_ is null")) { - visitor_->SetProgressStatusError( + visitor_->SetProgressStatusIfError( jump_after_first_.set_target(visitor_->GetCurrentIndex())); } } @@ -2403,7 +2419,8 @@ absl::Status ComprehensionVisitor::PostVisitArgDefault( break; } Jump jump_helper(index, jump_to_next); - visitor_->SetProgressStatusError(jump_helper.set_target(next_step_pos_)); + visitor_->SetProgressStatusIfError( + jump_helper.set_target(next_step_pos_)); // Set offsets jumping to the result step. if (cond_step_) { diff --git a/eval/compiler/flat_expr_builder_test.cc b/eval/compiler/flat_expr_builder_test.cc index d84007485..62ef1ded3 100644 --- a/eval/compiler/flat_expr_builder_test.cc +++ b/eval/compiler/flat_expr_builder_test.cc @@ -469,7 +469,7 @@ TEST(FlatExprBuilderTest, IdentExprUnsetName) { Expr expr; SourceInfo source_info; // An empty ident without the name set should error. - google::protobuf::TextFormat::ParseFromString(R"(ident_expr {})", &expr); + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"(ident_expr {})", &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); @@ -482,10 +482,10 @@ TEST(FlatExprBuilderTest, SelectExprUnsetField) { Expr expr; SourceInfo source_info; // An empty ident without the name set should error. - google::protobuf::TextFormat::ParseFromString(R"(select_expr{ + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"(select_expr{ operand{ ident_expr {name: 'var'} } })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); @@ -498,11 +498,11 @@ TEST(FlatExprBuilderTest, SelectExprUnsetOperand) { Expr expr; SourceInfo source_info; // An empty ident without the name set should error. - google::protobuf::TextFormat::ParseFromString(R"(select_expr{ + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"(select_expr{ field: 'field' operand { id: 1 } })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); @@ -515,7 +515,8 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetAccuVar) { Expr expr; SourceInfo source_info; // An empty ident without the name set should error. - google::protobuf::TextFormat::ParseFromString(R"(comprehension_expr{})", &expr); + ASSERT_TRUE( + google::protobuf::TextFormat::ParseFromString(R"(comprehension_expr{})", &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), @@ -527,10 +528,10 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetIterVar) { Expr expr; SourceInfo source_info; // An empty ident without the name set should error. - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( comprehension_expr{accu_var: "a"} )", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), @@ -542,12 +543,12 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetAccuInit) { Expr expr; SourceInfo source_info; // An empty ident without the name set should error. - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( comprehension_expr{ accu_var: "a" iter_var: "b"} )", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), @@ -559,7 +560,7 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetLoopCondition) { Expr expr; SourceInfo source_info; // An empty ident without the name set should error. - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( comprehension_expr{ accu_var: 'a' iter_var: 'b' @@ -567,7 +568,7 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetLoopCondition) { const_expr {bool_value: true} }} )", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), @@ -579,7 +580,7 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetLoopStep) { Expr expr; SourceInfo source_info; // An empty ident without the name set should error. - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( comprehension_expr{ accu_var: 'a' iter_var: 'b' @@ -590,7 +591,7 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetLoopStep) { const_expr {bool_value: true} }} )", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), @@ -602,7 +603,7 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetResult) { Expr expr; SourceInfo source_info; // An empty ident without the name set should error. - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( comprehension_expr{ accu_var: 'a' iter_var: 'b' @@ -616,7 +617,7 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetResult) { const_expr {bool_value: false} }} )", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), @@ -628,7 +629,7 @@ TEST(FlatExprBuilderTest, MapComprehension) { Expr expr; SourceInfo source_info; // {1: "", 2: ""}.all(x, x > 0) - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( comprehension_expr { iter_var: "k" accu_var: "accu" @@ -665,7 +666,7 @@ TEST(FlatExprBuilderTest, MapComprehension) { } } })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); @@ -683,7 +684,7 @@ TEST(FlatExprBuilderTest, InvalidContainer) { Expr expr; SourceInfo source_info; // foo && bar - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( call_expr { function: "_&&_" args { @@ -697,7 +698,7 @@ TEST(FlatExprBuilderTest, InvalidContainer) { } } })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); @@ -909,7 +910,7 @@ TEST(FlatExprBuilderTest, ParsedNamespacedFunctionSupportDisabled) { TEST(FlatExprBuilderTest, BasicCheckedExprSupport) { CheckedExpr expr; // foo && bar - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( expr { id: 1 call_expr { @@ -928,7 +929,7 @@ TEST(FlatExprBuilderTest, BasicCheckedExprSupport) { } } })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); @@ -946,7 +947,7 @@ TEST(FlatExprBuilderTest, BasicCheckedExprSupport) { TEST(FlatExprBuilderTest, CheckedExprWithReferenceMap) { CheckedExpr expr; // `foo.var1` && `bar.var2` - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( reference_map { key: 2 value { @@ -988,7 +989,7 @@ TEST(FlatExprBuilderTest, CheckedExprWithReferenceMap) { } } })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.flat_expr_builder().AddAstTransform( @@ -1008,7 +1009,7 @@ TEST(FlatExprBuilderTest, CheckedExprWithReferenceMap) { TEST(FlatExprBuilderTest, CheckedExprWithReferenceMapFunction) { CheckedExpr expr; // ext.and(var1, bar.var2) - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( reference_map { key: 1 value { @@ -1057,7 +1058,7 @@ TEST(FlatExprBuilderTest, CheckedExprWithReferenceMapFunction) { } } })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.flat_expr_builder().AddAstTransform( @@ -1082,7 +1083,7 @@ TEST(FlatExprBuilderTest, CheckedExprWithReferenceMapFunction) { TEST(FlatExprBuilderTest, CheckedExprActivationMissesReferences) { CheckedExpr expr; // && . - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( reference_map { key: 2 value { @@ -1125,7 +1126,7 @@ TEST(FlatExprBuilderTest, CheckedExprActivationMissesReferences) { } } })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.flat_expr_builder().AddAstTransform( @@ -1160,7 +1161,7 @@ TEST(FlatExprBuilderTest, CheckedExprActivationMissesReferences) { TEST(FlatExprBuilderTest, CheckedExprWithReferenceMapAndConstantFolding) { CheckedExpr expr; // {`var1`: 'hello'} - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( reference_map { key: 3 value { @@ -1190,7 +1191,7 @@ TEST(FlatExprBuilderTest, CheckedExprWithReferenceMapAndConstantFolding) { } } })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.flat_expr_builder().AddAstTransform( @@ -1213,7 +1214,7 @@ TEST(FlatExprBuilderTest, ComprehensionWorksForError) { Expr expr; SourceInfo source_info; // {}[0].all(x, x) should evaluate OK but return an error value - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( id: 4 comprehension_expr { iter_var: "x" @@ -1278,7 +1279,7 @@ TEST(FlatExprBuilderTest, ComprehensionWorksForError) { } } })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); @@ -1295,7 +1296,7 @@ TEST(FlatExprBuilderTest, ComprehensionWorksForNonContainer) { Expr expr; SourceInfo source_info; // 0.all(x, x) should evaluate OK but return an error value. - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( id: 4 comprehension_expr { iter_var: "x" @@ -1349,7 +1350,7 @@ TEST(FlatExprBuilderTest, ComprehensionWorksForNonContainer) { } } })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); @@ -1721,7 +1722,7 @@ TEST(FlatExprBuilderTest, NameCollisionWithComprehensionVarLeadingDot) { TEST(FlatExprBuilderTest, MapFieldPresence) { Expr expr; SourceInfo source_info; - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( id: 1, select_expr{ operand { @@ -1731,7 +1732,7 @@ TEST(FlatExprBuilderTest, MapFieldPresence) { field: "string_int32_map" test_only: true })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK_AND_ASSIGN(auto cel_expr, @@ -1765,7 +1766,7 @@ TEST(FlatExprBuilderTest, MapFieldPresence) { TEST(FlatExprBuilderTest, RepeatedFieldPresence) { Expr expr; SourceInfo source_info; - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( id: 1, select_expr{ operand { @@ -1775,7 +1776,7 @@ TEST(FlatExprBuilderTest, RepeatedFieldPresence) { field: "int32_list" test_only: true })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK_AND_ASSIGN(auto cel_expr, @@ -2900,6 +2901,234 @@ TEST(FlatExprBuilderTest, BlockNested) { HasSubstr("multiple cel.@block are not allowed"))); } +struct VariadicLogicalEvalTestCase { + std::string label; + std::string expr; + std::string a_val; + std::string b_val; + std::string c_val; + std::string expected_type; // "bool", "error", "unknown" + bool expected_bool = false; +}; + +class FlatExprBuilderVariadicLogicalTest + : public testing::TestWithParam {}; + +TEST_P(FlatExprBuilderVariadicLogicalTest, Evaluate) { + const auto& test_case = GetParam(); + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse(test_case.expr)); + + cel::RuntimeOptions options; + options.unknown_processing = + cel::UnknownProcessingOptions::kAttributeAndFunction; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder.CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + Activation activation; + google::protobuf::Arena arena; + std::vector unknown_patterns; + + // Set up variables: + auto insert_value = [&](absl::string_view name, const std::string& val) { + if (val == "true") { + activation.InsertValue(name, CelValue::CreateBool(true)); + } else if (val == "false") { + activation.InsertValue(name, CelValue::CreateBool(false)); + } else if (val == "error") { + activation.InsertValue(name, CreateErrorValue(&arena, "test error")); + } else if (val == "unknown1" || val == "unknown2") { + activation.InsertValue(name, CelValue::CreateBool(true)); + unknown_patterns.push_back(CreateCelAttributePattern(name, {})); + } + }; + + insert_value("a", test_case.a_val); + insert_value("b", test_case.b_val); + insert_value("c", test_case.c_val); + + if (!unknown_patterns.empty()) { + activation.set_unknown_attribute_patterns(std::move(unknown_patterns)); + } + + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); + + if (test_case.expected_type == "bool") { + ASSERT_TRUE(result.IsBool()) << result.DebugString(); + EXPECT_EQ(result.BoolOrDie(), test_case.expected_bool); + } else if (test_case.expected_type == "error") { + EXPECT_TRUE(result.IsError()) << result.DebugString(); + } else if (test_case.expected_type == "unknown") { + EXPECT_TRUE(result.IsUnknownSet()) << result.DebugString(); + } +} + +INSTANTIATE_TEST_SUITE_P( + FlatExprBuilderVariadicLogicalTest, FlatExprBuilderVariadicLogicalTest, + testing::Values( + VariadicLogicalEvalTestCase{"AND_AllTrue", "a && b && c", "true", + "true", "true", "bool", true}, + VariadicLogicalEvalTestCase{"AND_ShortCircuitFalse", "a && b && c", + "true", "false", "unset", "bool", false}, + VariadicLogicalEvalTestCase{"AND_ShortCircuitFirstFalse", "a && b && c", + "false", "unset", "unset", "bool", false}, + VariadicLogicalEvalTestCase{"OR_AllFalse", "a || b || c", "false", + "false", "false", "bool", false}, + VariadicLogicalEvalTestCase{"OR_ShortCircuitTrue", "a || b || c", + "false", "true", "unset", "bool", true}, + VariadicLogicalEvalTestCase{"OR_ShortCircuitFirstTrue", "a || b || c", + "true", "unset", "unset", "bool", true}, + VariadicLogicalEvalTestCase{"AND_Error", "a && b && c", "true", "error", + "true", "error"}, + VariadicLogicalEvalTestCase{"AND_ShortCircuitBeforeError", + "a && b && c", "false", "error", "unset", + "bool", false}, + VariadicLogicalEvalTestCase{"OR_Error", "a || b || c", "false", "error", + "false", "error"}, + VariadicLogicalEvalTestCase{"OR_ShortCircuitBeforeError", "a || b || c", + "true", "error", "unset", "bool", true}, + VariadicLogicalEvalTestCase{"AND_Unknown", "a && b && c", "true", + "unknown1", "true", "unknown"}, + VariadicLogicalEvalTestCase{"AND_ShortCircuitBeforeUnknown", + "a && b && c", "false", "unknown1", "unset", + "bool", false}, + VariadicLogicalEvalTestCase{"OR_Unknown", "a || b || c", "false", + "unknown1", "false", "unknown"}, + VariadicLogicalEvalTestCase{"OR_ShortCircuitBeforeUnknown", + "a || b || c", "true", "unknown1", "unset", + "bool", true}, + VariadicLogicalEvalTestCase{"AND_UnknownAggregation", "a && b && c", + "unknown1", "unknown2", "true", "unknown"}, + VariadicLogicalEvalTestCase{"OR_UnknownAggregation", "a || b || c", + "unknown1", "unknown2", "false", "unknown"}, + VariadicLogicalEvalTestCase{"Exists_True", "[a, b, c].exists(x, x)", + "false", "false", "true", "bool", true}, + VariadicLogicalEvalTestCase{"Exists_Unknown", "[a, b, c].exists(x, x)", + "false", "unknown1", "false", "unknown"}, + VariadicLogicalEvalTestCase{"All_False", "[a, b, c].all(x, x)", "true", + "true", "false", "bool", false}, + VariadicLogicalEvalTestCase{"All_Unknown", "[a, b, c].all(x, x)", + "true", "unknown1", "true", "unknown"})); + +struct RecursionDepthTestCase { + std::string label; + std::string expr; + int max_recursion_depth; + absl::StatusCode expected_status_code; + std::string expected_error_msg; +}; + +class FlatExprBuilderRecursionDepthTest + : public testing::TestWithParam {}; + +TEST_P(FlatExprBuilderRecursionDepthTest, CheckRecursionLimit) { + const auto& test_case = GetParam(); + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse(test_case.expr)); + + cel::RuntimeOptions options; + options.max_recursion_depth = test_case.max_recursion_depth; + options.fail_on_warnings = false; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + + auto result = + builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()); + if (test_case.expected_status_code == absl::StatusCode::kOk) { + EXPECT_THAT(result, IsOk()); + } else { + EXPECT_THAT(result, StatusIs(test_case.expected_status_code, + HasSubstr(test_case.expected_error_msg))); + } +} + +INSTANTIATE_TEST_SUITE_P( + FlatExprBuilderRecursionDepthTest, FlatExprBuilderRecursionDepthTest, + testing::Values( + RecursionDepthTestCase{"AndChildLimitExceeded", "(1 + 1) && true", 1, + absl::StatusCode::kInvalidArgument, + "Maximum recursion depth of 1 exceeded"}, + RecursionDepthTestCase{"AndParentLimitExceeded", "(1 + 1) && true", 2, + absl::StatusCode::kInvalidArgument, + "Maximum recursion depth of 2 exceeded"}, + RecursionDepthTestCase{"AndLimitSuccess", "(1 + 1) && true", 3, + absl::StatusCode::kOk, ""}, + RecursionDepthTestCase{"AndLimitSuccessGenerous", "(1 + 1) && true", 10, + absl::StatusCode::kOk, ""}, + RecursionDepthTestCase{"AndLimitSuccessUnlimited", "(1 + 1) && true", + -1, absl::StatusCode::kOk, ""}, + RecursionDepthTestCase{"OrChildLimitExceeded", "(1 + 1) || true", 1, + absl::StatusCode::kInvalidArgument, + "Maximum recursion depth of 1 exceeded"}, + RecursionDepthTestCase{"OrParentLimitExceeded", "(1 + 1) || true", 2, + absl::StatusCode::kInvalidArgument, + "Maximum recursion depth of 2 exceeded"}, + RecursionDepthTestCase{"OrLimitSuccess", "(1 + 1) || true", 3, + absl::StatusCode::kOk, ""})); + +TEST(FlatExprBuilderTest, NonRecursiveChildBlockAndError) { + ParsedExpr parsed_expr; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + expr: { + call_expr: { + function: "_&&_" + args { const_expr: { bool_value: true } } + args { + call_expr: { + function: "cel.@block" + args { + list_expr { elements { const_expr: { int64_value: 1 } } } + } + args { ident_expr: { name: "@index0" } } + } + } + } + } + )pb", + &parsed_expr)); + + cel::RuntimeOptions options; + options.max_recursion_depth = 2; + options.fail_on_warnings = false; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + EXPECT_THAT( + builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), + StatusIs(absl::StatusCode::kInternal, + HasSubstr("failed to build recursive program"))); +} + +TEST(FlatExprBuilderTest, NonRecursiveChildBlockOrError) { + ParsedExpr parsed_expr; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + expr: { + call_expr: { + function: "_||_" + args { const_expr: { bool_value: true } } + args { + call_expr: { + function: "cel.@block" + args { + list_expr { elements { const_expr: { int64_value: 1 } } } + } + args { ident_expr: { name: "@index0" } } + } + } + } + } + )pb", + &parsed_expr)); + + cel::RuntimeOptions options; + options.max_recursion_depth = 2; + options.fail_on_warnings = false; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + EXPECT_THAT( + builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), + StatusIs(absl::StatusCode::kInternal, + HasSubstr("failed to build recursive program"))); +} + } // namespace } // namespace google::api::expr::runtime diff --git a/parser/options.h b/parser/options.h index 916a941f0..719bed454 100644 --- a/parser/options.h +++ b/parser/options.h @@ -62,6 +62,10 @@ struct ParserOptions final { // Limited to field specifiers in select and message creation, // enabled by default bool enable_quoted_identifiers = true; + + // Enables parsing logical AND & OR operators as a single flat variadic call + // instead of a balanced/nested binary AST structure. + bool enable_variadic_logical_operators = false; }; } // namespace cel diff --git a/parser/parser.cc b/parser/parser.cc index 709e2fd41..6c6434319 100644 --- a/parser/parser.cc +++ b/parser/parser.cc @@ -552,7 +552,7 @@ class ExpressionBalancer final { // balance creates a balanced tree from the sub-terms and returns the final // Expr value. - Expr Balance(); + Expr Balance(bool enable_variadic = false); private: // balancedTree recursively balances the terms provided to a commutative @@ -577,10 +577,13 @@ void ExpressionBalancer::AddTerm(int64_t op, Expr term) { ops_.push_back(op); } -Expr ExpressionBalancer::Balance() { +Expr ExpressionBalancer::Balance(bool enable_variadic) { if (terms_.size() == 1) { return std::move(terms_[0]); } + if (enable_variadic) { + return factory_.NewCall(ops_[0], function_, std::move(terms_)); + } return BalancedTree(0, ops_.size() - 1); } @@ -620,7 +623,8 @@ class ParserVisitor final : public CelBaseVisitor, const cel::MacroRegistry& macro_registry, bool add_macro_calls = false, bool enable_optional_syntax = false, - bool enable_quoted_identifiers = false) + bool enable_quoted_identifiers = false, + bool enable_variadic_logical_operators = false) : source_(source), factory_(source_), macro_registry_(macro_registry), @@ -628,7 +632,8 @@ class ParserVisitor final : public CelBaseVisitor, max_recursion_depth_(max_recursion_depth), add_macro_calls_(add_macro_calls), enable_optional_syntax_(enable_optional_syntax), - enable_quoted_identifiers_(enable_quoted_identifiers) {} + enable_quoted_identifiers_(enable_quoted_identifiers), + enable_variadic_logical_operators_(enable_variadic_logical_operators) {} ~ParserVisitor() override = default; @@ -719,6 +724,7 @@ class ParserVisitor final : public CelBaseVisitor, const bool add_macro_calls_; const bool enable_optional_syntax_; const bool enable_quoted_identifiers_; + const bool enable_variadic_logical_operators_; }; template ParseImpl( ExprRecursionListener listener(options.max_recursion_depth); ParserVisitor visitor( source, options.max_recursion_depth, registry, options.add_macro_calls, - options.enable_optional_syntax, options.enable_quoted_identifiers); + options.enable_optional_syntax, options.enable_quoted_identifiers, + options.enable_variadic_logical_operators); lexer.removeErrorListeners(); parser.removeErrorListeners(); diff --git a/parser/parser_test.cc b/parser/parser_test.cc index 587b63a30..33c52b1d2 100644 --- a/parser/parser_test.cc +++ b/parser/parser_test.cc @@ -1782,6 +1782,59 @@ TEST(NewParserBuilderTest, ToBuilderPreservesStdlibAndOptionalFromOptions) { EXPECT_FALSE(ast->IsChecked()); } +struct VariadicLogicalOperatorsTestCase { + std::string input; + std::string expected_adorned_string; +}; + +class VariadicLogicalOperatorsTest + : public testing::TestWithParam {}; + +TEST_P(VariadicLogicalOperatorsTest, Parse) { + const auto& test_case = GetParam(); + auto builder = cel::NewParserBuilder(); + builder->GetOptions().enable_variadic_logical_operators = true; + ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); + ASSERT_OK_AND_ASSIGN(auto source, cel::NewSource(test_case.input)); + ASSERT_OK_AND_ASSIGN(auto ast, parser->Parse(*source)); + + KindAndIdAdorner kind_and_id_adorner; + ExprPrinter w(kind_and_id_adorner); + std::string adorned_string = w.Print(ast->root_expr()); + EXPECT_EQ(adorned_string, test_case.expected_adorned_string); +} + +INSTANTIATE_TEST_SUITE_P( + VariadicLogicalOperators, VariadicLogicalOperatorsTest, + testing::Values( + VariadicLogicalOperatorsTestCase{ + .input = "a && b && c && d", + .expected_adorned_string = "_&&_(\n" + " a^#1:Expr.Ident#,\n" + " b^#2:Expr.Ident#,\n" + " c^#4:Expr.Ident#,\n" + " d^#6:Expr.Ident#\n" + ")^#3:Expr.Call#"}, + VariadicLogicalOperatorsTestCase{ + .input = "a || b || c || d", + .expected_adorned_string = "_||_(\n" + " a^#1:Expr.Ident#,\n" + " b^#2:Expr.Ident#,\n" + " c^#4:Expr.Ident#,\n" + " d^#6:Expr.Ident#\n" + ")^#3:Expr.Call#"}, + VariadicLogicalOperatorsTestCase{ + .input = "a && b && (c || d || e)", + .expected_adorned_string = "_&&_(\n" + " a^#1:Expr.Ident#,\n" + " b^#2:Expr.Ident#,\n" + " _||_(\n" + " c^#4:Expr.Ident#,\n" + " d^#5:Expr.Ident#,\n" + " e^#7:Expr.Ident#\n" + " )^#6:Expr.Call#\n" + ")^#3:Expr.Call#"})); + TEST(ParserTest, ParseFailurePopulatesIssues) { auto builder = cel::NewParserBuilder(); ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); diff --git a/tools/cel_unparser.cc b/tools/cel_unparser.cc index 28a1187bb..741d91208 100644 --- a/tools/cel_unparser.cc +++ b/tools/cel_unparser.cc @@ -150,6 +150,8 @@ class Unparser { // - a ternary conditional operator bool IsBinaryOrTernaryOperator(const Expr& expr); + bool IsLogicalOperator(absl::string_view op); + template void Print(Ts&&... args) { absl::StrAppend(&output_, std::forward(args)...); @@ -436,6 +438,24 @@ absl::Status Unparser::VisitUnary(const Expr::Call& expr, absl::Status Unparser::VisitBinary(const Expr::Call& expr, const std::string& op) { + if (expr.args_size() < 2) { + return absl::InvalidArgumentError( + absl::StrCat("Unexpected binary: ", expr.ShortDebugString())); + } + + const auto& fun = expr.function(); + if (IsLogicalOperator(fun)) { + for (int i = 0; i < expr.args_size(); ++i) { + if (i > 0) { + Print(kSpace, op, kSpace); + } + const auto& arg = expr.args(i); + bool arg_paren = IsComplexOperatorWithRespectTo(arg, fun); + CEL_RETURN_IF_ERROR(VisitMaybeNested(arg, arg_paren)); + } + return absl::OkStatus(); + } + if (expr.args_size() != 2) { return absl::InvalidArgumentError( absl::StrCat("Unexpected binary: ", expr.ShortDebugString())); @@ -443,7 +463,6 @@ absl::Status Unparser::VisitBinary(const Expr::Call& expr, const auto& lhs = expr.args(0); const auto& rhs = expr.args(1); - const auto& fun = expr.function(); // add parens if the current operator is lower precedence than the lhs expr // operator. @@ -549,6 +568,10 @@ bool Unparser::IsBinaryOrTernaryOperator(const Expr& expr) { IsOperatorSamePrecedence(CelOperator::CONDITIONAL, expr); } +bool Unparser::IsLogicalOperator(absl::string_view op) { + return op == CelOperator::LOGICAL_AND || op == CelOperator::LOGICAL_OR; +} + } // namespace absl::StatusOr Unparse(const Expr& expr, diff --git a/tools/cel_unparser_test.cc b/tools/cel_unparser_test.cc index 4cba4ce4d..aca6e91fd 100644 --- a/tools/cel_unparser_test.cc +++ b/tools/cel_unparser_test.cc @@ -67,6 +67,22 @@ INSTANTIATE_TEST_SUITE_P( {// Empty Expr error {"", absl::InvalidArgumentError("Unsupported Expr")}, + // Logical operators with too few arguments (single argument) + { + R"pb( + call_expr { + function: "_&&_" + args { const_expr { bool_value: true } } + })pb", + absl::InvalidArgumentError("Unexpected binary")}, + { + R"pb( + call_expr { + function: "_||_" + args { const_expr { bool_value: true } } + })pb", + absl::InvalidArgumentError("Unexpected binary")}, + // Constants {"const_expr{}", absl::InvalidArgumentError("Unsupported Constant")}, {"const_expr{bool_value: true}", "true"}, @@ -619,6 +635,7 @@ TEST_P(UnparserTestTextExpr, Test) { options.add_macro_calls = true; options.enable_optional_syntax = true; options.enable_quoted_identifiers = true; + options.enable_variadic_logical_operators = true; ASSERT_OK_AND_ASSIGN(ParsedExpr result, Parse(GetParam().expr, "unparser", options)); @@ -779,6 +796,8 @@ INSTANTIATE_TEST_SUITE_P( {"has(a.`b.c`)", ""}, {"a.`b/c`", ""}, {"a.?`b/c`", ""}, + {"a && b && c && d", ""}, + {"a || b || c || d", ""}, })); } // namespace