From 678118dad4bd104426f49dc2537f6a999477df7e Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 16 Oct 2025 11:05:29 -0700 Subject: [PATCH] Add the function AddContextDeclarationWithProtoTypeMask to the type checker. PiperOrigin-RevId: 820311481 --- checker/internal/BUILD | 6 + checker/internal/type_check_env.cc | 5 + checker/internal/type_check_env.h | 14 + checker/internal/type_checker_builder_impl.cc | 56 +++- checker/internal/type_checker_builder_impl.h | 9 + .../type_checker_builder_impl_test.cc | 278 ++++++++++++++++++ checker/type_checker_builder.h | 19 ++ checker/type_checker_builder_factory_test.cc | 47 +++ 8 files changed, 428 insertions(+), 6 deletions(-) diff --git a/checker/internal/BUILD b/checker/internal/BUILD index 25550616a..c518cf0b7 100644 --- a/checker/internal/BUILD +++ b/checker/internal/BUILD @@ -66,6 +66,8 @@ cc_library( hdrs = ["type_check_env.h"], deps = [ ":descriptor_pool_type_introspector", + ":proto_type_mask", + ":proto_type_mask_registry", "//common:constant", "//common:container", "//common:decl", @@ -76,6 +78,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", @@ -129,6 +132,7 @@ cc_library( deps = [ ":format_type_name", ":namespace_generator", + ":proto_type_mask", ":type_check_env", ":type_inference_context", "//checker:checker_options", @@ -153,6 +157,7 @@ cc_library( "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:absl_check", @@ -225,6 +230,7 @@ cc_test( "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:optional", "@com_google_protobuf//:protobuf", diff --git a/checker/internal/type_check_env.cc b/checker/internal/type_check_env.cc index 763d9ba46..47487220c 100644 --- a/checker/internal/type_check_env.cc +++ b/checker/internal/type_check_env.cc @@ -16,6 +16,7 @@ #include #include +#include #include "absl/base/nullability.h" #include "absl/status/statusor.h" @@ -96,6 +97,10 @@ absl::StatusOr> TypeCheckEnv::LookupTypeConstant( absl::StatusOr> TypeCheckEnv::LookupStructField( absl::string_view type_name, absl::string_view field_name) const { + if (proto_type_mask_registry_ != nullptr && + !proto_type_mask_registry_->FieldIsVisible(type_name, field_name)) { + return absl::nullopt; + } // Check the type providers in registration order. // Note: this doesn't allow for shadowing a type with a subset type of the // same name -- the later type provider will still be considered when diff --git a/checker/internal/type_check_env.h b/checker/internal/type_check_env.h index 15f8ecc4d..00fea0ba3 100644 --- a/checker/internal/type_check_env.h +++ b/checker/internal/type_check_env.h @@ -25,16 +25,20 @@ #include "absl/container/flat_hash_map.h" #include "absl/log/absl_check.h" #include "absl/memory/memory.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "checker/internal/descriptor_pool_type_introspector.h" +#include "checker/internal/proto_type_mask.h" +#include "checker/internal/proto_type_mask_registry.h" #include "common/constant.h" #include "common/container.h" #include "common/decl.h" #include "common/type.h" #include "common/type_introspector.h" +#include "internal/status_macros.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" @@ -154,6 +158,14 @@ class TypeCheckEnv { variables_[decl.name()] = std::move(decl); } + absl::Status CreateProtoTypeMaskRegistry( + const std::vector& proto_type_masks) { + CEL_ASSIGN_OR_RETURN(proto_type_mask_registry_, + ProtoTypeMaskRegistry::Create(descriptor_pool_.get(), + proto_type_masks)); + return absl::OkStatus(); + } + const absl::flat_hash_map& functions() const { return functions_; } @@ -224,6 +236,8 @@ class TypeCheckEnv { absl::flat_hash_map variables_; absl::flat_hash_map functions_; + std::shared_ptr proto_type_mask_registry_; + // Type providers for custom types. std::vector> type_providers_; diff --git a/checker/internal/type_checker_builder_impl.cc b/checker/internal/type_checker_builder_impl.cc index 85b581e83..9b91fc926 100644 --- a/checker/internal/type_checker_builder_impl.cc +++ b/checker/internal/type_checker_builder_impl.cc @@ -16,6 +16,7 @@ #include #include +#include #include #include #include @@ -23,13 +24,16 @@ #include "absl/base/no_destructor.h" #include "absl/base/nullability.h" #include "absl/cleanup/cleanup.h" +#include "absl/container/btree_set.h" #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/log/absl_log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "checker/internal/proto_type_mask.h" #include "checker/internal/type_check_env.h" #include "checker/internal/type_checker_impl.h" #include "checker/type_checker.h" @@ -86,10 +90,19 @@ absl::Status CheckStdMacroOverlap(const FunctionDecl& decl) { } absl::Status AddWellKnownContextDeclarationVariables( - const google::protobuf::Descriptor* absl_nonnull descriptor, TypeCheckEnv& env, - bool use_json_name) { + const google::protobuf::Descriptor* absl_nonnull descriptor, + const absl::flat_hash_map>& + context_type_fields, + TypeCheckEnv& env, bool use_json_name) { for (int i = 0; i < descriptor->field_count(); ++i) { const google::protobuf::FieldDescriptor* field = descriptor->field(i); + // Skip fields that are hidden because of a proto type mask. + auto map_iterator = context_type_fields.find(descriptor->full_name()); + if (map_iterator != context_type_fields.end() && + !map_iterator->second.contains(field->name())) { + continue; + } Type type = MessageTypeField(field).GetType(); if (type.IsEnum()) { type = IntType(); @@ -109,11 +122,15 @@ absl::Status AddWellKnownContextDeclarationVariables( } absl::Status AddContextDeclarationVariables( - const google::protobuf::Descriptor* absl_nonnull descriptor, TypeCheckEnv& env) { + const google::protobuf::Descriptor* absl_nonnull descriptor, + const absl::flat_hash_map>& + context_type_fields, + TypeCheckEnv& env) { const bool use_json_name = env.proto_type_introspector().use_json_name(); if (IsWellKnownMessageType(descriptor)) { - return AddWellKnownContextDeclarationVariables(descriptor, env, - use_json_name); + return AddWellKnownContextDeclarationVariables( + descriptor, context_type_fields, env, use_json_name); } CEL_ASSIGN_OR_RETURN(auto fields, env.proto_type_introspector().ListFieldsForStructType( @@ -131,6 +148,13 @@ absl::Status AddContextDeclarationVariables( absl::string_view name = field_entry.name; + // Skip fields that are hidden because of a proto type mask. + auto map_iterator = context_type_fields.find(descriptor->full_name()); + if (map_iterator != context_type_fields.end() && + !map_iterator->second.contains(name)) { + continue; + } + if (!env.InsertVariableIfAbsent(MakeVariableDecl(name, type))) { return absl::AlreadyExistsError( absl::StrCat("variable '", name, @@ -317,7 +341,8 @@ absl::Status TypeCheckerBuilderImpl::ApplyConfig( } for (const google::protobuf::Descriptor* context_type : config.context_types) { - CEL_RETURN_IF_ERROR(AddContextDeclarationVariables(context_type, env)); + CEL_RETURN_IF_ERROR(AddContextDeclarationVariables( + context_type, config.context_type_fields, env)); } for (VariableDeclRecord& var : config.variables) { @@ -339,6 +364,8 @@ absl::Status TypeCheckerBuilderImpl::ApplyConfig( } } + CEL_RETURN_IF_ERROR(env.CreateProtoTypeMaskRegistry(config.proto_type_masks)); + return absl::OkStatus(); } @@ -462,6 +489,23 @@ absl::Status TypeCheckerBuilderImpl::AddContextDeclaration( return absl::OkStatus(); } +absl::Status TypeCheckerBuilderImpl::AddContextDeclarationWithProtoTypeMask( + absl::string_view type, std::vector field_paths) { + if (field_paths.empty()) { + return absl::InvalidArgumentError("field paths cannot be the empty set"); + } + + ProtoTypeMask proto_type_mask(std::string(type), field_paths); + target_config_->proto_type_masks.push_back(proto_type_mask); + + CEL_RETURN_IF_ERROR(AddContextDeclaration(type)); + CEL_ASSIGN_OR_RETURN( + absl::btree_set field_names, + proto_type_mask.GetFieldNames(template_env_.descriptor_pool())); + target_config_->context_type_fields.insert({type, std::move(field_names)}); + return absl::OkStatus(); +} + absl::Status TypeCheckerBuilderImpl::AddFunction(const FunctionDecl& decl) { CEL_RETURN_IF_ERROR( ValidateFunctionDecl(decl, options_.enable_type_parameter_name_validation, diff --git a/checker/internal/type_checker_builder_impl.h b/checker/internal/type_checker_builder_impl.h index 646a5d16f..9895a8aee 100644 --- a/checker/internal/type_checker_builder_impl.h +++ b/checker/internal/type_checker_builder_impl.h @@ -21,6 +21,7 @@ #include #include "absl/base/nullability.h" +#include "absl/container/btree_set.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" @@ -28,6 +29,7 @@ #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "checker/checker_options.h" +#include "checker/internal/proto_type_mask.h" #include "checker/internal/type_check_env.h" #include "checker/type_checker.h" #include "checker/type_checker_builder.h" @@ -76,6 +78,8 @@ class TypeCheckerBuilderImpl : public TypeCheckerBuilder { absl::Status AddVariable(const VariableDecl& decl) override; absl::Status AddOrReplaceVariable(const VariableDecl& decl) override; absl::Status AddContextDeclaration(absl::string_view type) override; + absl::Status AddContextDeclarationWithProtoTypeMask( + absl::string_view type, std::vector field_paths) override; absl::Status AddFunction(const FunctionDecl& decl) override; absl::Status MergeFunction(const FunctionDecl& decl) override; @@ -130,6 +134,11 @@ class TypeCheckerBuilderImpl : public TypeCheckerBuilder { std::vector functions; std::vector> type_providers; std::vector context_types; + // Maps context type names to fields names to add as variables. + // Only includes context types that are defined with proto type masks. + absl::flat_hash_map> + context_type_fields; + std::vector proto_type_masks; }; absl::Status BuildLibraryConfig(const CheckerLibrary& library, diff --git a/checker/internal/type_checker_builder_impl_test.cc b/checker/internal/type_checker_builder_impl_test.cc index 494e7e440..f81b3098f 100644 --- a/checker/internal/type_checker_builder_impl_test.cc +++ b/checker/internal/type_checker_builder_impl_test.cc @@ -15,12 +15,15 @@ #include "checker/internal/type_checker_builder_impl.h" #include +#include #include #include +#include #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "checker/checker_options.h" @@ -107,6 +110,135 @@ INSTANTIATE_TEST_SUITE_P( MapTypeSpec(std::make_unique(PrimitiveType::kString), std::make_unique(DynTypeSpec())))})); +TEST(ContextDeclsWithProtoTypeMaskTest, ErrorOnEmptyFieldPaths) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {}), + StatusIs(absl::StatusCode::kInvalidArgument, + "field paths cannot be the empty set")); +} + +TEST(ContextDeclsWithProtoTypeMaskTest, ErrorOnUnknownFieldPath) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT( + builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"unknown_field"}), + StatusIs(absl::StatusCode::kInvalidArgument, + "could not select field 'unknown_field' from type " + "'cel.expr.conformance.proto3.TestAllTypes'")); +} + +class ContextDeclsWithProtoTypeMaskFieldsDefinedTest + : public testing::TestWithParam {}; + +std::string LogFieldName(absl::string_view field_name, absl::string_view expr) { + return absl::StrCat("field_name: ", field_name, ", expr: ", expr); +} + +TEST_P(ContextDeclsWithProtoTypeMaskFieldsDefinedTest, + ContextDeclsWithProtoTypeMaskFieldsDefined) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT( + builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {GetParam().expr}), + IsOk()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, + builder.Build()); + std::vector field_names = { + "single_int64", "single_uint32", "single_double", + "single_string", "single_any", "single_duration", + "single_bool_wrapper", "list_value", "standalone_message", + "standalone_enum", "repeated_bytes", "repeated_nested_message", + "map_int32_timestamp", "single_struct"}; + for (auto& field_name : field_names) { + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst(field_name)); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + type_checker->Check(std::move(ast))); + if (field_name == GetParam().expr) { + // The field name that is part of the proto type mask is visible. + ASSERT_TRUE(result.IsValid()) + << LogFieldName(field_name, GetParam().expr); + EXPECT_EQ(result.GetAst()->GetReturnType(), GetParam().expected_type) + << LogFieldName(field_name, GetParam().expr); + } else { + // The field names that are not part of the proto type mask are not + // visible. + EXPECT_FALSE(result.IsValid()) + << LogFieldName(field_name, GetParam().expr); + } + } +} + +INSTANTIATE_TEST_SUITE_P( + TestAllTypes, ContextDeclsWithProtoTypeMaskFieldsDefinedTest, + testing::Values( + ContextDeclsTestCase{"single_int64", TypeSpec(PrimitiveType::kInt64)}, + ContextDeclsTestCase{"single_uint32", TypeSpec(PrimitiveType::kUint64)}, + ContextDeclsTestCase{"single_double", TypeSpec(PrimitiveType::kDouble)}, + ContextDeclsTestCase{"single_string", TypeSpec(PrimitiveType::kString)}, + ContextDeclsTestCase{"single_any", TypeSpec(WellKnownTypeSpec::kAny)}, + ContextDeclsTestCase{"single_duration", + TypeSpec(WellKnownTypeSpec::kDuration)}, + ContextDeclsTestCase{ + "single_bool_wrapper", + TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBool))}, + ContextDeclsTestCase{ + "list_value", + TypeSpec(ListTypeSpec(std::make_unique(DynTypeSpec())))}, + ContextDeclsTestCase{ + "standalone_message", + TypeSpec(MessageTypeSpec( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage"))}, + ContextDeclsTestCase{"standalone_enum", + TypeSpec(PrimitiveType::kInt64)}, + ContextDeclsTestCase{"repeated_bytes", + TypeSpec(ListTypeSpec(std::make_unique( + PrimitiveType::kBytes)))}, + ContextDeclsTestCase{ + "repeated_nested_message", + TypeSpec(ListTypeSpec(std::make_unique(MessageTypeSpec( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage"))))}, + ContextDeclsTestCase{ + "map_int32_timestamp", + TypeSpec(MapTypeSpec( + std::make_unique(PrimitiveType::kInt64), + std::make_unique(WellKnownTypeSpec::kTimestamp)))}, + ContextDeclsTestCase{ + "single_struct", + TypeSpec( + MapTypeSpec(std::make_unique(PrimitiveType::kString), + std::make_unique(DynTypeSpec())))})); + +TEST(ContextDeclsWithProtoTypeMaskTest, FieldsInMaskAreVisibleFieldAccess) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.NestedTestAllTypes", + {"payload.standalone_message.bb", "payload.single_int32"}), + IsOk()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, + builder.Build()); + // Visible field: standalone_message.bb + ASSERT_OK_AND_ASSIGN(auto ast, + MakeTestParsedAst("payload.standalone_message.bb")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + type_checker->Check(std::move(ast))); + ASSERT_TRUE(result.IsValid()); + EXPECT_EQ(result.GetAst()->GetReturnType(), TypeSpec(PrimitiveType::kInt64)); + // Visible field: single_int32 + ASSERT_OK_AND_ASSIGN(ast, MakeTestParsedAst("payload.single_int32")); + ASSERT_OK_AND_ASSIGN(result, type_checker->Check(std::move(ast))); + ASSERT_TRUE(result.IsValid()); + EXPECT_EQ(result.GetAst()->GetReturnType(), TypeSpec(PrimitiveType::kInt64)); + // Not Visible field: single_int64 + ASSERT_OK_AND_ASSIGN(ast, MakeTestParsedAst("payload.single_int64")); + ASSERT_OK_AND_ASSIGN(result, type_checker->Check(std::move(ast))); + ASSERT_FALSE(result.IsValid()); +} + TEST(ContextDeclsTest, ErrorOnDuplicateContextDeclaration) { TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), {}); @@ -120,6 +252,20 @@ TEST(ContextDeclsTest, ErrorOnDuplicateContextDeclaration) { "already exists")); } +TEST(ContextDeclsWithProtoTypeMaskTest, ErrorOnDuplicateContextDeclaration) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT( + builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"standalone_message"}), + IsOk()); + EXPECT_THAT( + builder.AddContextDeclaration("cel.expr.conformance.proto3.TestAllTypes"), + StatusIs(absl::StatusCode::kAlreadyExists, + "context declaration 'cel.expr.conformance.proto3.TestAllTypes' " + "already exists")); +} + TEST(ContextDeclsTest, ErrorOnContextDeclarationNotFound) { TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), {}); @@ -129,6 +275,16 @@ TEST(ContextDeclsTest, ErrorOnContextDeclarationNotFound) { "context declaration 'com.example.UnknownType' not found")); } +TEST(ContextDeclsWithProtoTypeMaskTest, ErrorOnContextDeclarationNotFound) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + EXPECT_THAT( + builder.AddContextDeclarationWithProtoTypeMask("com.example.UnknownType", + {"any_field_name"}), + StatusIs(absl::StatusCode::kNotFound, + "context declaration 'com.example.UnknownType' not found")); +} + TEST(ContextDeclsTest, ErrorOnNonStructMessageType) { TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), {}); @@ -139,6 +295,17 @@ TEST(ContextDeclsTest, ErrorOnNonStructMessageType) { "context declaration 'google.protobuf.Timestamp' is not a struct")); } +TEST(ContextDeclsWithProtoTypeMaskTest, ErrorOnNonStructMessageType) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + EXPECT_THAT( + builder.AddContextDeclarationWithProtoTypeMask( + "google.protobuf.Timestamp", {"any_field_name"}), + StatusIs( + absl::StatusCode::kInvalidArgument, + "context declaration 'google.protobuf.Timestamp' is not a struct")); +} + TEST(ContextDeclsTest, CustomStructNotSupported) { TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), {}); @@ -160,6 +327,28 @@ TEST(ContextDeclsTest, CustomStructNotSupported) { "context declaration 'com.example.MyStruct' not found")); } +TEST(ContextDeclsWithProtoTypeMaskTest, CustomStructNotSupported) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + class MyTypeProvider : public cel::TypeIntrospector { + public: + absl::StatusOr> FindTypeImpl( + absl::string_view name) const override { + if (name == "com.example.MyStruct") { + return common_internal::MakeBasicStructType("com.example.MyStruct"); + } + return absl::nullopt; + } + }; + + builder.AddTypeProvider(std::make_unique()); + + EXPECT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "com.example.MyStruct", {"any_field_name"}), + StatusIs(absl::StatusCode::kNotFound, + "context declaration 'com.example.MyStruct' not found")); +} + TEST(ContextDeclsTest, ErrorOnOverlappingContextDeclaration) { TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), {}); @@ -179,6 +368,69 @@ TEST(ContextDeclsTest, ErrorOnOverlappingContextDeclaration) { "declaration: 'cel.expr.conformance.proto2.TestAllTypes')")); } +TEST(ContextDeclsWithProtoTypeMaskTest, ErrorOnOverlappingContextDeclaration) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT( + builder.AddContextDeclaration("cel.expr.conformance.proto3.TestAllTypes"), + IsOk()); + // We resolve the context declaration variables at the Build() call, so the + // error surfaces then. + ASSERT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto2.TestAllTypes", {"single_int32"}), + IsOk()); + + EXPECT_THAT( + builder.Build(), + StatusIs(absl::StatusCode::kAlreadyExists, + "variable 'single_int32' declared multiple times (from context " + "declaration: 'cel.expr.conformance.proto2.TestAllTypes')")); +} + +TEST(ContextDeclsWithProtoTypeMaskTest, + ErrorOnOverlappingContextDeclarationBothProtoTypeMasks) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"single_int32"}), + IsOk()); + // We resolve the context declaration variables at the Build() call, so the + // error surfaces then. + ASSERT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto2.TestAllTypes", {"single_int32"}), + IsOk()); + + EXPECT_THAT( + builder.Build(), + StatusIs(absl::StatusCode::kAlreadyExists, + "variable 'single_int32' declared multiple times (from context " + "declaration: 'cel.expr.conformance.proto2.TestAllTypes')")); +} + +TEST(ContextDeclsWithProtoTypeMaskTest, + NonOverlappingContextDeclarationBothProtoTypeMasks) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"single_int32"}), + IsOk()); + // We resolve the context declaration variables at the Build() call, so the + // error surfaces then. + ASSERT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto2.NestedTestAllTypes", + {"payload.single_int64"}), + IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, + builder.Build()); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("single_int32")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + type_checker->Check(std::move(ast))); + ASSERT_TRUE(result.IsValid()); + ASSERT_OK_AND_ASSIGN(ast, MakeTestParsedAst("payload.single_int64")); + ASSERT_OK_AND_ASSIGN(result, type_checker->Check(std::move(ast))); +} + TEST(ContextDeclsTest, ErrorOnOverlappingVariableDeclaration) { TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), {}); @@ -193,6 +445,32 @@ TEST(ContextDeclsTest, ErrorOnOverlappingVariableDeclaration) { "variable 'single_int64' declared multiple times")); } +TEST(ContextDeclsWithProtoTypeMaskTest, ErrorOnOverlappingVariableDeclaration) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"single_int64"}), + IsOk()); + ASSERT_THAT(builder.AddVariable(MakeVariableDecl("single_int64", IntType())), + IsOk()); + + EXPECT_THAT(builder.Build(), + StatusIs(absl::StatusCode::kAlreadyExists, + "variable 'single_int64' declared multiple times")); +} + +TEST(ContextDeclsWithProtoTypeMaskTest, NonOverlappingVariableDeclaration) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"single_int32"}), + IsOk()); + ASSERT_THAT(builder.AddVariable(MakeVariableDecl("single_int64", IntType())), + IsOk()); + + EXPECT_THAT(builder.Build(), IsOk()); +} + TEST(TypeCheckerBuilderImplTest, InvalidTypeParamNameVariableValidationDisabled) { CheckerOptions options; diff --git a/checker/type_checker_builder.h b/checker/type_checker_builder.h index 5dd1f5256..8df9d2217 100644 --- a/checker/type_checker_builder.h +++ b/checker/type_checker_builder.h @@ -17,6 +17,7 @@ #include #include +#include #include "absl/base/nullability.h" #include "absl/functional/any_invocable.h" @@ -102,6 +103,24 @@ class TypeCheckerBuilder { // Note: only protobuf backed struct types are supported at this time. virtual absl::Status AddContextDeclaration(absl::string_view type) = 0; + // Declares struct type by fully qualified name as a context declaration. + // + // Adds a ProtoTypeMask (similar to a FieldMask) from the field paths that + // defines the visible fields for each type. + // + // Context declarations are a way to declare a group of variables based on the + // definition of a struct type. Each top level field of the struct that is + // also the first field name in a field path is declared as an individual + // variable of the field type. + // + // It is an error if the type contains a field that overlaps with another + // declared variable. It is an error if the input field paths is the empty + // set. + // + // Note: only protobuf backed struct types are supported at this time. + virtual absl::Status AddContextDeclarationWithProtoTypeMask( + absl::string_view type, std::vector field_paths) = 0; + // Adds a function declaration that may be referenced in expressions checked // with the resulting TypeChecker. virtual absl::Status AddFunction(const FunctionDecl& decl) = 0; diff --git a/checker/type_checker_builder_factory_test.cc b/checker/type_checker_builder_factory_test.cc index 38430de5f..9c4775e7f 100644 --- a/checker/type_checker_builder_factory_test.cc +++ b/checker/type_checker_builder_factory_test.cc @@ -396,6 +396,27 @@ TEST(TypeCheckerBuilderTest, AddContextDeclaration) { EXPECT_TRUE(result.IsValid()); } +TEST(TypeCheckerBuilderTest, AddContextDeclarationWithProtoTypeMask) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); + + ASSERT_OK_AND_ASSIGN( + auto fn_decl, + MakeFunctionDecl("increment", MakeOverloadDecl("increment_int", IntType(), + IntType()))); + + ASSERT_THAT(builder->AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"single_int64"}), + IsOk()); + ASSERT_THAT(builder->AddFunction(fn_decl), IsOk()); + + ASSERT_OK_AND_ASSIGN(auto checker, builder->Build()); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("increment(single_int64)")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, checker->Check(std::move(ast))); + EXPECT_TRUE(result.IsValid()); +} + TEST(TypeCheckerBuilderTest, WellKnownTypeContextDeclarationError) { ASSERT_OK_AND_ASSIGN( std::unique_ptr builder, @@ -428,6 +449,32 @@ TEST(TypeCheckerBuilderTest, AllowWellKnownTypeContextDeclaration) { ASSERT_TRUE(result.IsValid()); } +TEST(TypeCheckerBuilderTest, + AllowWellKnownTypeContextDeclarationWithProtoTypeMask) { + CheckerOptions options; + options.allow_well_known_type_context_declarations = true; + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool(), options)); + + ASSERT_THAT(builder->AddContextDeclarationWithProtoTypeMask( + "google.protobuf.Any", {"value"}), + IsOk()); + ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, + builder->Build()); + // Visible field: value + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("value")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + type_checker->Check(std::move(ast))); + ASSERT_TRUE(result.IsValid()); + // Not visible field: type_url + ASSERT_OK_AND_ASSIGN(ast, MakeTestParsedAst("type_url")); + ASSERT_OK_AND_ASSIGN(result, type_checker->Check(std::move(ast))); + ASSERT_FALSE(result.IsValid()); +} + TEST(TypeCheckerBuilderTest, AllowWellKnownTypeContextDeclarationStruct) { CheckerOptions options; options.allow_well_known_type_context_declarations = true;