From d503e172e7fc75623b0e765bf00cc63ca986185d Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Mon, 13 Oct 2025 11:08:20 -0700 Subject: [PATCH] Add ProtoTypeMaskProvider, ProtoTypeMask, and FieldPath classes. The ProtoTypeMaskProvider has functions that can be used to validate the input field masks and to check whether a field is visible. PiperOrigin-RevId: 818745535 --- checker/internal/BUILD | 78 ++++ checker/internal/field_path.h | 73 ++++ checker/internal/field_path_test.cc | 78 ++++ checker/internal/proto_type_mask.h | 74 ++++ checker/internal/proto_type_mask_registry.cc | 189 ++++++++ checker/internal/proto_type_mask_registry.h | 87 ++++ .../internal/proto_type_mask_registry_test.cc | 409 ++++++++++++++++++ checker/internal/proto_type_mask_test.cc | 58 +++ 8 files changed, 1046 insertions(+) create mode 100644 checker/internal/field_path.h create mode 100644 checker/internal/field_path_test.cc create mode 100644 checker/internal/proto_type_mask.h create mode 100644 checker/internal/proto_type_mask_registry.cc create mode 100644 checker/internal/proto_type_mask_registry.h create mode 100644 checker/internal/proto_type_mask_registry_test.cc create mode 100644 checker/internal/proto_type_mask_test.cc diff --git a/checker/internal/BUILD b/checker/internal/BUILD index 0f8f28f66..ff4135dd6 100644 --- a/checker/internal/BUILD +++ b/checker/internal/BUILD @@ -271,3 +271,81 @@ cc_test( "@com_google_protobuf//:protobuf", ], ) + +cc_library( + name = "field_path", + hdrs = ["field_path.h"], + deps = [ + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "field_path_test", + srcs = ["field_path_test.cc"], + deps = [ + ":field_path", + "//internal:testing", + ], +) + +cc_library( + name = "proto_type_mask", + hdrs = ["proto_type_mask.h"], + deps = [ + ":field_path", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "proto_type_mask_test", + srcs = ["proto_type_mask_test.cc"], + deps = [ + ":field_path", + ":proto_type_mask", + "//internal:testing", + ], +) + +cc_library( + name = "proto_type_mask_registry", + srcs = ["proto_type_mask_registry.cc"], + hdrs = ["proto_type_mask_registry.h"], + deps = [ + ":field_path", + ":proto_type_mask", + "//common:type", + "//internal:status_macros", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "proto_type_mask_registry_test", + srcs = ["proto_type_mask_registry_test.cc"], + deps = [ + ":proto_type_mask", + ":proto_type_mask_registry", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/strings:string_view", + ], +) diff --git a/checker/internal/field_path.h b/checker/internal/field_path.h new file mode 100644 index 000000000..e16598d59 --- /dev/null +++ b/checker/internal/field_path.h @@ -0,0 +1,73 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_CHECKER_FIELD_PATH_H_ +#define THIRD_PARTY_CEL_CPP_CHECKER_FIELD_PATH_H_ + +#include +#include +#include + +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "absl/types/span.h" + +namespace cel::checker_internal { + +// Represents a single path within a FieldMask. +class FieldPath { + public: + explicit FieldPath(std::string path) + : path_(std::move(path)), + field_selection_(absl::StrSplit(path_, kPathDelimiter)) {} + + absl::string_view GetPath() const { return path_; } + + absl::Span GetFieldSelection() const { + return field_selection_; + } + + // Returns the first field name in the path. + std::string GetFieldName() const { return field_selection_.front(); } + + std::string DebugString() const { + return absl::Substitute( + "FieldPath { field path: '$0', field selection: {'$1'} }", path_, + absl::StrJoin(field_selection_, "', '")); + } + + private: + static inline constexpr char kPathDelimiter = '.'; + + // The input path. For example: "f.b.d". + std::string path_; + // The list of nested field names in the path. For example: {"f", "b", "d"}. + std::vector field_selection_; +}; + +inline bool operator==(const FieldPath& lhs, const FieldPath& rhs) { + return lhs.GetFieldSelection() == rhs.GetFieldSelection(); +} + +// Compares the field selections in the field paths. +// This is only intended as an arbitrary ordering for a set. +inline bool operator<(const FieldPath& lhs, const FieldPath& rhs) { + return lhs.GetFieldSelection() < rhs.GetFieldSelection(); +} + +} // namespace cel::checker_internal + +#endif // THIRD_PARTY_CEL_CPP_CHECKER_FIELD_PATH_H_ diff --git a/checker/internal/field_path_test.cc b/checker/internal/field_path_test.cc new file mode 100644 index 000000000..2ed3e297f --- /dev/null +++ b/checker/internal/field_path_test.cc @@ -0,0 +1,78 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/field_path.h" + +#include "internal/testing.h" + +namespace cel::checker_internal { +namespace { + +using ::testing::ElementsAre; + +TEST(FieldPathTest, EmptyPathReturnsEmptyString) { + FieldPath field_path(""); + EXPECT_EQ(field_path.GetPath(), ""); + EXPECT_THAT(field_path.GetFieldSelection(), ElementsAre("")); + EXPECT_EQ(field_path.GetFieldName(), ""); +} + +TEST(FieldPathTest, DelimiterPathReturnsEmptyStrings) { + FieldPath field_path("."); + EXPECT_EQ(field_path.GetPath(), "."); + EXPECT_THAT(field_path.GetFieldSelection(), ElementsAre("", "")); + EXPECT_EQ(field_path.GetFieldName(), ""); +} + +TEST(FieldPathTest, FieldPathReturnsFields) { + FieldPath field_path("resource.name.other_field"); + EXPECT_EQ(field_path.GetPath(), "resource.name.other_field"); + EXPECT_THAT(field_path.GetFieldSelection(), + ElementsAre("resource", "name", "other_field")); + EXPECT_EQ(field_path.GetFieldName(), "resource"); +} + +TEST(FieldPathTest, DebugStringPrintsFieldSelection) { + FieldPath field_path("resource.name"); + EXPECT_EQ(field_path.DebugString(), + "FieldPath { field path: 'resource.name', field selection: " + "{'resource', 'name'} }"); +} + +TEST(FieldPathTest, EqualsComparesFieldSelectionAndReturnsTrue) { + FieldPath field_path_1("resource.name"); + FieldPath field_path_2("resource.name"); + EXPECT_TRUE(field_path_1 == field_path_2); +} + +TEST(FieldPathTest, EqualsComparesFieldSelectionAndReturnsFalse) { + FieldPath field_path_1("resource.name"); + FieldPath field_path_2("resource.type"); + EXPECT_FALSE(field_path_1 == field_path_2); +} + +TEST(FieldPathTest, LessThanComparesFieldSelectionAndReturnsTrue) { + FieldPath field_path_1("resource.name"); + FieldPath field_path_2("resource.type"); + EXPECT_TRUE(field_path_1 < field_path_2); +} + +TEST(FieldPathTest, LessThanComparesFieldSelectionAndReturnsFalse) { + FieldPath field_path_1("resource.name"); + FieldPath field_path_2("resource.name"); + EXPECT_FALSE(field_path_1 < field_path_2); +} + +} // namespace +} // namespace cel::checker_internal diff --git a/checker/internal/proto_type_mask.h b/checker/internal/proto_type_mask.h new file mode 100644 index 000000000..e24fc5156 --- /dev/null +++ b/checker/internal/proto_type_mask.h @@ -0,0 +1,74 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_CHECKER_PROTO_TYPE_MASK_H_ +#define THIRD_PARTY_CEL_CPP_CHECKER_PROTO_TYPE_MASK_H_ + +#include +#include +#include +#include + +#include "absl/container/btree_set.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "checker/internal/field_path.h" + +namespace cel::checker_internal { + +// Represents the fraction of a protobuf type's object graph that should be +// visible within CEL expressions. +class ProtoTypeMask { + public: + explicit ProtoTypeMask(std::string type_name, + const std::set& field_paths) + : type_name_(std::move(type_name)) { + for (const std::string& field_path : field_paths) { + field_paths_.insert(FieldPath(field_path)); + } + }; + + absl::string_view GetTypeName() const { return type_name_; } + + const absl::btree_set& GetFieldPaths() const { + return field_paths_; + } + + std::string DebugString() const { + // Represent each FieldPath by its path because it is easiest to read. + std::vector paths; + paths.reserve(field_paths_.size()); + for (const FieldPath& field_path : field_paths_) { + paths.emplace_back(field_path.GetPath()); + } + return absl::Substitute( + "ProtoTypeMask { type name: '$0', field paths: { '$1' } }", + type_name_, absl::StrJoin(paths, "', '")); + } + + private: + // A type's full name. For example: "google.rpc.context.AttributeContext". + std::string type_name_; + // A representation of a FieldMask, which is a set of field paths. + // A FieldMask contains one or more paths which contain identifier characters + // that have been dot delimited, e.g. resource.name, request.auth.claims. + // For each path, all descendent fields after the last element in the path are + // visible. An empty set means all fields are hidden. + absl::btree_set field_paths_; +}; + +} // namespace cel::checker_internal + +#endif // THIRD_PARTY_CEL_CPP_CHECKER_PROTO_TYPE_MASK_H_ diff --git a/checker/internal/proto_type_mask_registry.cc b/checker/internal/proto_type_mask_registry.cc new file mode 100644 index 000000000..38c8f58c7 --- /dev/null +++ b/checker/internal/proto_type_mask_registry.cc @@ -0,0 +1,189 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/proto_type_mask_registry.h" + +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/btree_set.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "checker/internal/field_path.h" +#include "checker/internal/proto_type_mask.h" +#include "common/type.h" +#include "internal/status_macros.h" +#include "google/protobuf/descriptor.h" + +namespace cel::checker_internal { +namespace { + +using ::google::protobuf::Descriptor; +using ::google::protobuf::DescriptorPool; +using ::google::protobuf::FieldDescriptor; +using TypeMap = + absl::flat_hash_map>; + +absl::StatusOr FindMessage( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + absl::string_view type_name) { + const Descriptor* descriptor = + descriptor_pool->FindMessageTypeByName(type_name); + if (descriptor == nullptr) { + return absl::InvalidArgumentError( + absl::Substitute("type '$0' not found", type_name)); + } + return descriptor; +} + +absl::StatusOr FindField(const Descriptor* descriptor, + absl::string_view field_name) { + const FieldDescriptor* field_descriptor = + descriptor->FindFieldByName(field_name); + if (field_descriptor == nullptr) { + return absl::InvalidArgumentError( + absl::Substitute("could not select field '$0' from type '$1'", + field_name, descriptor->full_name())); + } + return field_descriptor; +} + +absl::StatusOr GetMessage( + const FieldDescriptor* field_descriptor) { + cel::MessageTypeField field(field_descriptor); + cel::Type type = field.GetType(); + absl::optional message_type = type.AsMessage(); + if (!message_type.has_value()) { + return absl::InvalidArgumentError(absl::Substitute( + "field '$0' is not a message type", field_descriptor->name())); + } + return &(*message_type.value()); +} + +absl::Status AddAllHiddenFields(TypeMap& types_and_visible_fields, + absl::string_view type_name) { + auto result = types_and_visible_fields.find(type_name); + if (result != types_and_visible_fields.end()) { + if (!result->second.empty()) { + return absl::InvalidArgumentError( + absl::Substitute("cannot insert a proto type mask with all hidden " + "fields when type '$0' has already been inserted " + "with a proto type mask with a visible field", + type_name)); + } + return absl::OkStatus(); + } + types_and_visible_fields.insert({type_name.data(), {}}); + return absl::OkStatus(); +} + +absl::Status AddVisibleField(TypeMap& types_and_visible_fields, + absl::string_view type_name, + absl::string_view field_name) { + auto result = types_and_visible_fields.find(type_name); + if (result != types_and_visible_fields.end()) { + if (result->second.empty()) { + return absl::InvalidArgumentError(absl::Substitute( + "cannot insert a proto type mask with visible " + "field '$0' when type '$1' has already been inserted " + "with a proto type mask with all hidden fields", + field_name, type_name)); + } + result->second.insert(field_name.data()); + return absl::OkStatus(); + } + types_and_visible_fields.insert({type_name.data(), {field_name.data()}}); + return absl::OkStatus(); +} + +absl::StatusOr ComputeVisibleFieldsMap( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + const std::vector& proto_type_masks) { + TypeMap types_and_visible_fields; + for (const ProtoTypeMask& proto_type_mask : proto_type_masks) { + absl::string_view type_name = proto_type_mask.GetTypeName(); + CEL_ASSIGN_OR_RETURN(const Descriptor* descriptor, + FindMessage(descriptor_pool, type_name)); + const absl::btree_set field_paths = + proto_type_mask.GetFieldPaths(); + if (field_paths.empty()) { + CEL_RETURN_IF_ERROR( + AddAllHiddenFields(types_and_visible_fields, type_name)); + } + for (const FieldPath& field_path : field_paths) { + const Descriptor* target_descriptor = descriptor; + absl::Span field_selection = + field_path.GetFieldSelection(); + for (auto iterator = field_selection.begin(); + iterator != field_selection.end(); ++iterator) { + CEL_ASSIGN_OR_RETURN(const FieldDescriptor* field_descriptor, + FindField(target_descriptor, *iterator)); + CEL_RETURN_IF_ERROR(AddVisibleField(types_and_visible_fields, + target_descriptor->full_name(), + *iterator)); + if (std::next(iterator) != field_selection.end()) { + CEL_ASSIGN_OR_RETURN(target_descriptor, GetMessage(field_descriptor)); + } + } + } + } + return types_and_visible_fields; +} + +} // namespace + +absl::StatusOr> GetFieldNames( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + const ProtoTypeMask& proto_type_mask) { + CEL_ASSIGN_OR_RETURN( + const Descriptor* descriptor, + FindMessage(descriptor_pool, proto_type_mask.GetTypeName())); + absl::flat_hash_set field_names; + for (const FieldPath& field_path : proto_type_mask.GetFieldPaths()) { + std::string field_name = field_path.GetFieldName(); + CEL_ASSIGN_OR_RETURN(const FieldDescriptor* field_descriptor, + FindField(descriptor, field_name)); + field_names.insert(field_descriptor->name()); + } + return field_names; +} + +absl::StatusOr ProtoTypeMaskRegistry::Create( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + const std::vector& proto_type_masks) { + CEL_ASSIGN_OR_RETURN( + auto types_and_visible_fields, + ComputeVisibleFieldsMap(descriptor_pool, proto_type_masks)); + return ProtoTypeMaskRegistry(types_and_visible_fields); +} + +bool ProtoTypeMaskRegistry::FieldIsVisible(absl::string_view type_name, + absl::string_view field_name) { + auto iterator = types_and_visible_fields_.find(type_name); + if (iterator != types_and_visible_fields_.end() && + !iterator->second.contains(field_name)) { + return false; + } + return true; +} + +} // namespace cel::checker_internal diff --git a/checker/internal/proto_type_mask_registry.h b/checker/internal/proto_type_mask_registry.h new file mode 100644 index 000000000..24e9513e3 --- /dev/null +++ b/checker/internal/proto_type_mask_registry.h @@ -0,0 +1,87 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_CHECKER_PROTO_TYPE_MASK_REGISTRY_H_ +#define THIRD_PARTY_CEL_CPP_CHECKER_PROTO_TYPE_MASK_REGISTRY_H_ + +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "checker/internal/proto_type_mask.h" +#include "google/protobuf/descriptor.h" + +namespace cel::checker_internal { + +// Returns a set of field names for the input proto type mask. +// The set includes the first field name from each field path. +absl::StatusOr> GetFieldNames( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + const ProtoTypeMask& proto_type_mask); + +// Stores information related to ProtoTypeMasks. Visibility is defined per type, +// meaning that all messages of a type have the same visible fields. +class ProtoTypeMaskRegistry { + public: + // Processes the input proto type masks to create + static absl::StatusOr Create( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + const std::vector& proto_type_masks); + + const absl::flat_hash_map>& + GetTypesAndVisibleFields() const { + return types_and_visible_fields_; + } + + // Returns true when the field name is visible. A field is visible if: + // 1. The type name is not a key in the map. + // 2. The type name is a key in the map and the field name is in the set of + // field names that are visible for the type. + bool FieldIsVisible(absl::string_view type_name, + absl::string_view field_name); + + std::string DebugString() const { + std::string output = "ProtoTypeMaskRegistry { "; + for (auto& element : types_and_visible_fields_) { + absl::StrAppend(&output, "{type: '", element.first, + "', visible_fields: '", + absl::StrJoin(element.second, "', '"), "'} "); + } + absl::StrAppend(&output, "}"); + return output; + } + + private: + explicit ProtoTypeMaskRegistry( + absl::flat_hash_map> + types_and_visible_fields) + : types_and_visible_fields_(std::move(types_and_visible_fields)) {} + + // Map of types that have a field mask where the keys are + // fully qualified type names and the values are the set of field names that + // are visible for the type. + absl::flat_hash_map> + types_and_visible_fields_; +}; + +} // namespace cel::checker_internal + +#endif // THIRD_PARTY_CEL_CPP_CHECKER_PROTO_TYPE_MASK_REGISTRY_H_ diff --git a/checker/internal/proto_type_mask_registry_test.cc b/checker/internal/proto_type_mask_registry_test.cc new file mode 100644 index 000000000..62dd68be2 --- /dev/null +++ b/checker/internal/proto_type_mask_registry_test.cc @@ -0,0 +1,409 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/proto_type_mask_registry.h" + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/strings/string_view.h" +#include "checker/internal/proto_type_mask.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" + +namespace cel::checker_internal { +namespace { + +using ::absl_testing::StatusIs; +using ::cel::internal::GetSharedTestingDescriptorPool; +using ::testing::HasSubstr; +using ::testing::IsEmpty; +using ::testing::Pair; +using ::testing::UnorderedElementsAre; + +using TypeMap = + absl::flat_hash_map>; + +TEST(ProtoTypeMaskRegistryTest, GetFieldNamesWithEmptyTypeReturnsError) { + ProtoTypeMask proto_type_mask("", {}); + EXPECT_THAT( + GetFieldNames(GetSharedTestingDescriptorPool().get(), proto_type_mask), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("type '' not found"))); +} + +TEST(ProtoTypeMaskRegistryTest, GetFieldNamesWithUnknownTypeReturnsError) { + ProtoTypeMask proto_type_mask("com.example.UnknownType", {}); + EXPECT_THAT( + GetFieldNames(GetSharedTestingDescriptorPool().get(), proto_type_mask), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("type 'com.example.UnknownType' not found"))); +} + +TEST(ProtoTypeMaskRegistryTest, + GetFieldNamesWithEmptySetFieldPathSucceedsAndReturnsEmptySet) { + ProtoTypeMask proto_type_mask("cel.expr.conformance.proto3.TestAllTypes", {}); + ASSERT_OK_AND_ASSIGN( + absl::flat_hash_set field_names, + GetFieldNames(GetSharedTestingDescriptorPool().get(), proto_type_mask)); + EXPECT_THAT(field_names, IsEmpty()); +} + +TEST(ProtoTypeMaskRegistryTest, GetFieldNamesWithEmptyFieldPathReturnsError) { + ProtoTypeMask proto_type_mask("cel.expr.conformance.proto3.TestAllTypes", + {""}); + EXPECT_THAT( + GetFieldNames(GetSharedTestingDescriptorPool().get(), proto_type_mask), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("could not select field '' from type " + "'cel.expr.conformance.proto3.TestAllTypes'"))); +} + +TEST(ProtoTypeMaskRegistryTest, GetFieldNamesWithUnknownFieldReturnsError) { + ProtoTypeMask proto_type_mask("cel.expr.conformance.proto3.TestAllTypes", + {"unknown_field"}); + EXPECT_THAT( + GetFieldNames(GetSharedTestingDescriptorPool().get(), proto_type_mask), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("could not select field 'unknown_field' from type " + "'cel.expr.conformance.proto3.TestAllTypes'"))); +} + +TEST(ProtoTypeMaskRegistryTest, + GetFieldNamesWithListOfFieldPathsSucceedsAndReturnsFieldNames) { + ProtoTypeMask proto_type_mask( + "cel.expr.conformance.proto3.NestedTestAllTypes", + {"payload.standalone_message.bb", "payload.single_int32", + "child.any_field_name"}); + ASSERT_OK_AND_ASSIGN( + absl::flat_hash_set field_names, + GetFieldNames(GetSharedTestingDescriptorPool().get(), proto_type_mask)); + EXPECT_THAT(field_names, UnorderedElementsAre("payload", "child")); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithEmptyInputSucceedsAndAllFieldsAreVisible) { + std::vector proto_type_masks = {}; + ASSERT_OK_AND_ASSIGN( + ProtoTypeMaskRegistry proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT(proto_type_mask_registry.GetTypesAndVisibleFields(), IsEmpty()); + EXPECT_TRUE(proto_type_mask_registry.FieldIsVisible("any_type_name", + "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, CreateWithEmptyTypeReturnsError) { + std::vector proto_type_masks = {ProtoTypeMask("", {})}; + EXPECT_THAT(ProtoTypeMaskRegistry::Create( + GetSharedTestingDescriptorPool().get(), proto_type_masks), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("type '' not found"))); +} + +TEST(ProtoTypeMaskRegistryTest, CreateWithUnknownTypeReturnsError) { + std::vector proto_type_masks = { + ProtoTypeMask("com.example.UnknownType", {})}; + EXPECT_THAT(ProtoTypeMaskRegistry::Create( + GetSharedTestingDescriptorPool().get(), proto_type_masks), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("type 'com.example.UnknownType' not found"))); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithEmptySetFieldPathSucceedsAndFieldsAreHidden) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", {})}; + ASSERT_OK_AND_ASSIGN( + ProtoTypeMaskRegistry proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT(proto_type_mask_registry.GetTypesAndVisibleFields(), + UnorderedElementsAre( + Pair("cel.expr.conformance.proto3.TestAllTypes", IsEmpty()))); + EXPECT_TRUE(proto_type_mask_registry.FieldIsVisible("any_type_name", + "any_field_name")); + EXPECT_FALSE(proto_type_mask_registry.FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithDuplicateEmptySetFieldPathSucceedsAndFieldsAreHidden) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", {}), + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", {})}; + ASSERT_OK_AND_ASSIGN( + ProtoTypeMaskRegistry proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT(proto_type_mask_registry.GetTypesAndVisibleFields(), + UnorderedElementsAre( + Pair("cel.expr.conformance.proto3.TestAllTypes", IsEmpty()))); + EXPECT_TRUE(proto_type_mask_registry.FieldIsVisible("any_type_name", + "any_field_name")); + EXPECT_FALSE(proto_type_mask_registry.FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, CreateWithEmptyFieldPathReturnsError) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", {""})}; + EXPECT_THAT( + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("could not select field '' from type " + "'cel.expr.conformance.proto3.TestAllTypes'"))); +} + +TEST(ProtoTypeMaskRegistryTest, CreateWithUnknownFieldReturnsError) { + std::vector proto_type_masks = {ProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"unknown_field"})}; + EXPECT_THAT( + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("could not select field 'unknown_field' from type " + "'cel.expr.conformance.proto3.TestAllTypes'"))); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithDepthOneNonMessageFieldsSucceedsAndFieldsAreVisible) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", + {"single_int32", "single_any", "single_timestamp"})}; + ASSERT_OK_AND_ASSIGN( + ProtoTypeMaskRegistry proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT(proto_type_mask_registry.GetTypesAndVisibleFields(), + UnorderedElementsAre( + Pair("cel.expr.conformance.proto3.TestAllTypes", + UnorderedElementsAre("single_int32", "single_any", + "single_timestamp")))); + EXPECT_TRUE(proto_type_mask_registry.FieldIsVisible("any_type_name", + "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry.FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "single_int32")); + EXPECT_TRUE(proto_type_mask_registry.FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "single_any")); + EXPECT_TRUE(proto_type_mask_registry.FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "single_timestamp")); + EXPECT_FALSE(proto_type_mask_registry.FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, CreateWithDepthTwoNonMessageFieldReturnsError) { + std::vector proto_type_masks; + proto_type_masks.push_back( + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", + {"single_int32.any_field_name"})); + EXPECT_THAT( + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("field 'single_int32' is not a message type"))); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithDepthOneMessageFieldSucceedsAndFieldsAreVisible) { + std::vector proto_type_masks = {ProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"standalone_message"})}; + ASSERT_OK_AND_ASSIGN( + ProtoTypeMaskRegistry proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT( + proto_type_mask_registry.GetTypesAndVisibleFields(), + UnorderedElementsAre(Pair("cel.expr.conformance.proto3.TestAllTypes", + UnorderedElementsAre("standalone_message")))); + EXPECT_TRUE(proto_type_mask_registry.FieldIsVisible("any_type_name", + "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry.FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "standalone_message")); + EXPECT_FALSE(proto_type_mask_registry.FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithDepthTwoMessageFieldSucceedsAndFieldsAreVisible) { + std::vector proto_type_masks = {ProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"standalone_message.bb"})}; + ASSERT_OK_AND_ASSIGN( + ProtoTypeMaskRegistry proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT(proto_type_mask_registry.GetTypesAndVisibleFields(), + UnorderedElementsAre( + Pair("cel.expr.conformance.proto3.TestAllTypes", + UnorderedElementsAre("standalone_message")), + Pair("cel.expr.conformance.proto3.TestAllTypes.NestedMessage", + UnorderedElementsAre("bb")))); + EXPECT_TRUE(proto_type_mask_registry.FieldIsVisible("any_type_name", + "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry.FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "standalone_message")); + EXPECT_FALSE(proto_type_mask_registry.FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry.FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage", "bb")); + EXPECT_FALSE(proto_type_mask_registry.FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage", + "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, CreateWithDepthTwoUnknownFieldReturnsError) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", + {"standalone_message.unknown_field"})}; + EXPECT_THAT( + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks), + StatusIs( + absl::StatusCode::kInvalidArgument, + HasSubstr( + "could not select field 'unknown_field' from type " + "'cel.expr.conformance.proto3.TestAllTypes.NestedMessage'"))); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithDepthThreeMessageFieldSucceedsAndFieldsAreVisible) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.NestedTestAllTypes", + {"payload.standalone_message.bb"})}; + ASSERT_OK_AND_ASSIGN( + ProtoTypeMaskRegistry proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT(proto_type_mask_registry.GetTypesAndVisibleFields(), + UnorderedElementsAre( + Pair("cel.expr.conformance.proto3.NestedTestAllTypes", + UnorderedElementsAre("payload")), + Pair("cel.expr.conformance.proto3.TestAllTypes", + UnorderedElementsAre("standalone_message")), + Pair("cel.expr.conformance.proto3.TestAllTypes.NestedMessage", + UnorderedElementsAre("bb")))); + EXPECT_TRUE(proto_type_mask_registry.FieldIsVisible("any_type_name", + "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry.FieldIsVisible( + "cel.expr.conformance.proto3.NestedTestAllTypes", "payload")); + EXPECT_FALSE(proto_type_mask_registry.FieldIsVisible( + "cel.expr.conformance.proto3.NestedTestAllTypes", "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry.FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "standalone_message")); + EXPECT_FALSE(proto_type_mask_registry.FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry.FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage", "bb")); + EXPECT_FALSE(proto_type_mask_registry.FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage", + "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithListOfFieldPathsSucceedsAndFieldsAreVisible) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.NestedTestAllTypes", + {"payload.standalone_message.bb", "payload.single_int32"})}; + ASSERT_OK_AND_ASSIGN( + ProtoTypeMaskRegistry proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT( + proto_type_mask_registry.GetTypesAndVisibleFields(), + UnorderedElementsAre( + Pair("cel.expr.conformance.proto3.NestedTestAllTypes", + UnorderedElementsAre("payload")), + Pair("cel.expr.conformance.proto3.TestAllTypes", + UnorderedElementsAre("standalone_message", "single_int32")), + Pair("cel.expr.conformance.proto3.TestAllTypes.NestedMessage", + UnorderedElementsAre("bb")))); + EXPECT_TRUE(proto_type_mask_registry.FieldIsVisible("any_type_name", + "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry.FieldIsVisible( + "cel.expr.conformance.proto3.NestedTestAllTypes", "payload")); + EXPECT_FALSE(proto_type_mask_registry.FieldIsVisible( + "cel.expr.conformance.proto3.NestedTestAllTypes", "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry.FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "standalone_message")); + EXPECT_TRUE(proto_type_mask_registry.FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "single_int32")); + EXPECT_FALSE(proto_type_mask_registry.FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry.FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage", "bb")); + EXPECT_FALSE(proto_type_mask_registry.FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage", + "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateAddVisibleFieldThenAllHiddenFieldsReturnsError) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", + {"standalone_message.bb"}), + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes.NestedMessage", + {})}; + EXPECT_THAT( + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks), + StatusIs( + absl::StatusCode::kInvalidArgument, + HasSubstr( + "cannot insert a proto type mask with all hidden fields when " + "type 'cel.expr.conformance.proto3.TestAllTypes.NestedMessage' " + "has already been inserted with a proto type mask with a visible " + "field"))); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateAddAllHiddenThenVisibleFieldReturnsError) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes.NestedMessage", + {}), + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", + {"standalone_message.bb"})}; + EXPECT_THAT( + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks), + StatusIs( + absl::StatusCode::kInvalidArgument, + HasSubstr( + "cannot insert a proto type mask with visible field 'bb' when " + "type 'cel.expr.conformance.proto3.TestAllTypes.NestedMessage' " + "has already been inserted with a proto type mask with all " + "hidden fields"))); +} + +TEST(ProtoTypeMaskRegistryTest, DebugStringPrintsTypesAndVisibleFieldsMap) { + std::vector proto_type_masks = {ProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"standalone_message.bb"})}; + ASSERT_OK_AND_ASSIGN( + ProtoTypeMaskRegistry proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT(proto_type_mask_registry.DebugString(), + HasSubstr("ProtoTypeMaskRegistry { {type: " + "'cel.expr.conformance.proto3.TestAllTypes', " + "visible_fields: 'standalone_message'} {type: " + "'cel.expr.conformance.proto3.TestAllTypes." + "NestedMessage', visible_fields: 'bb'} }")); +} +} // namespace +} // namespace cel::checker_internal diff --git a/checker/internal/proto_type_mask_test.cc b/checker/internal/proto_type_mask_test.cc new file mode 100644 index 000000000..2cf418514 --- /dev/null +++ b/checker/internal/proto_type_mask_test.cc @@ -0,0 +1,58 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/proto_type_mask.h" + +#include +#include + +#include "checker/internal/field_path.h" +#include "internal/testing.h" + +namespace cel::checker_internal { +namespace { + +using ::testing::HasSubstr; +using ::testing::IsEmpty; +using ::testing::UnorderedElementsAre; + +TEST(ProtoTypeMaskTest, EmptyTypeNameAndEmptyFieldPathsSucceeds) { + std::string type_name = ""; + std::set field_paths; + ProtoTypeMask proto_type_mask(type_name, field_paths); + EXPECT_EQ(proto_type_mask.GetTypeName(), ""); + EXPECT_THAT(proto_type_mask.GetFieldPaths(), IsEmpty()); +} + +TEST(ProtoTypeMaskTest, NotEmptyTypeNameAndNotEmptyFieldPathsSucceeds) { + std::string type_name = "google.type.Expr"; + std::set field_paths = {"resource.name", "resource.type"}; + ProtoTypeMask proto_type_mask(type_name, field_paths); + EXPECT_EQ(proto_type_mask.GetTypeName(), "google.type.Expr"); + EXPECT_THAT(proto_type_mask.GetFieldPaths(), + UnorderedElementsAre(FieldPath("resource.name"), + FieldPath("resource.type"))); +} + +TEST(ProtoTypeMaskTest, DebugStringPrintsTypeNameAndFieldPaths) { + std::string type_name = "google.type.Expr"; + std::set field_paths = {"resource.name", "resource.type"}; + ProtoTypeMask proto_type_mask(type_name, field_paths); + EXPECT_THAT(proto_type_mask.DebugString(), + HasSubstr("ProtoTypeMask { type name: 'google.type.Expr', field " + "paths: { 'resource.name', 'resource.type' } }")); +} + +} // namespace +} // namespace cel::checker_internal