diff --git a/checker/internal/BUILD b/checker/internal/BUILD index f4c60f937..693d17f99 100644 --- a/checker/internal/BUILD +++ b/checker/internal/BUILD @@ -66,6 +66,8 @@ cc_library( hdrs = ["type_check_env.h"], deps = [ ":descriptor_pool_type_introspector", + ":proto_type_mask", + ":proto_type_mask_registry", "//common:constant", "//common:container", "//common:decl", @@ -76,6 +78,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", @@ -129,6 +132,8 @@ cc_library( deps = [ ":format_type_name", ":namespace_generator", + ":proto_type_mask", + ":proto_type_mask_registry", ":type_check_env", ":type_inference_context", "//checker:checker_options", @@ -310,3 +315,88 @@ cc_test( "@com_google_absl//absl/types:optional", ], ) + +cc_library( + name = "field_path", + srcs = ["field_path.cc"], + hdrs = ["field_path.h"], + deps = [ + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "field_path_test", + srcs = ["field_path_test.cc"], + deps = [ + ":field_path", + "//internal:testing", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "proto_type_mask", + srcs = ["proto_type_mask.cc"], + hdrs = ["proto_type_mask.h"], + deps = [ + ":field_path", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "proto_type_mask_test", + srcs = ["proto_type_mask_test.cc"], + deps = [ + ":field_path", + ":proto_type_mask", + "//internal:testing", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "proto_type_mask_registry", + srcs = ["proto_type_mask_registry.cc"], + hdrs = ["proto_type_mask_registry.h"], + deps = [ + ":field_path", + ":proto_type_mask", + "//common:type", + "//internal:status_macros", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "proto_type_mask_registry_test", + srcs = ["proto_type_mask_registry_test.cc"], + deps = [ + ":proto_type_mask", + ":proto_type_mask_registry", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + ], +) diff --git a/checker/internal/field_path.cc b/checker/internal/field_path.cc new file mode 100644 index 000000000..5ecc4219b --- /dev/null +++ b/checker/internal/field_path.cc @@ -0,0 +1,30 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/field_path.h" + +#include + +#include "absl/strings/str_join.h" +#include "absl/strings/substitute.h" + +namespace cel::checker_internal { + +std::string FieldPath::DebugString() const { + return absl::Substitute( + "FieldPath { field path: '$0', field selection: {'$1'} }", path_, + absl::StrJoin(field_selection_, "', '")); +} + +} // namespace cel::checker_internal diff --git a/checker/internal/field_path.h b/checker/internal/field_path.h new file mode 100644 index 000000000..79a8285f8 --- /dev/null +++ b/checker/internal/field_path.h @@ -0,0 +1,73 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_CHECKER_FIELD_PATH_H_ +#define THIRD_PARTY_CEL_CPP_CHECKER_FIELD_PATH_H_ + +#include +#include +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" + +namespace cel::checker_internal { + +// Represents a single path within a FieldMask. +class FieldPath { + public: + explicit FieldPath(std::string path) + : path_(std::move(path)), + field_selection_(absl::StrSplit(path_, kPathDelimiter)) {} + + absl::string_view GetPath() const { return path_; } + + absl::Span GetFieldSelection() const { + return field_selection_; + } + + // Returns the first field name in the path. + std::string GetFieldName() const { return field_selection_.front(); } + + template + friend void AbslStringify(Sink& sink, const FieldPath& field_path) { + absl::Format(&sink, "%v", field_path.DebugString()); + }; + + private: + static constexpr char kPathDelimiter = '.'; + + std::string DebugString() const; + + // The input path. For example: "f.b.d". + std::string path_; + // The list of nested field names in the path. For example: {"f", "b", "d"}. + std::vector field_selection_; +}; + +inline bool operator==(const FieldPath& lhs, const FieldPath& rhs) { + return lhs.GetFieldSelection() == rhs.GetFieldSelection(); +} + +// Compares the field selections in the field paths. +// This is only intended as an arbitrary ordering for a set. +inline bool operator<(const FieldPath& lhs, const FieldPath& rhs) { + return lhs.GetFieldSelection() < rhs.GetFieldSelection(); +} + +} // namespace cel::checker_internal + +#endif // THIRD_PARTY_CEL_CPP_CHECKER_FIELD_PATH_H_ diff --git a/checker/internal/field_path_test.cc b/checker/internal/field_path_test.cc new file mode 100644 index 000000000..d79a0f02a --- /dev/null +++ b/checker/internal/field_path_test.cc @@ -0,0 +1,79 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/field_path.h" + +#include "absl/strings/str_cat.h" +#include "internal/testing.h" + +namespace cel::checker_internal { +namespace { + +using ::testing::ElementsAre; + +TEST(FieldPathTest, EmptyPathReturnsEmptyString) { + FieldPath field_path(""); + EXPECT_EQ(field_path.GetPath(), ""); + EXPECT_THAT(field_path.GetFieldSelection(), ElementsAre("")); + EXPECT_EQ(field_path.GetFieldName(), ""); +} + +TEST(FieldPathTest, DelimiterPathReturnsEmptyStrings) { + FieldPath field_path("."); + EXPECT_EQ(field_path.GetPath(), "."); + EXPECT_THAT(field_path.GetFieldSelection(), ElementsAre("", "")); + EXPECT_EQ(field_path.GetFieldName(), ""); +} + +TEST(FieldPathTest, FieldPathReturnsFields) { + FieldPath field_path("resource.name.other_field"); + EXPECT_EQ(field_path.GetPath(), "resource.name.other_field"); + EXPECT_THAT(field_path.GetFieldSelection(), + ElementsAre("resource", "name", "other_field")); + EXPECT_EQ(field_path.GetFieldName(), "resource"); +} + +TEST(FieldPathTest, AbslStringifyPrintsFieldSelection) { + FieldPath field_path("resource.name"); + EXPECT_EQ(absl::StrCat(field_path), + "FieldPath { field path: 'resource.name', field selection: " + "{'resource', 'name'} }"); +} + +TEST(FieldPathTest, EqualsComparesFieldSelectionAndReturnsTrue) { + FieldPath field_path_1("resource.name"); + FieldPath field_path_2("resource.name"); + EXPECT_TRUE(field_path_1 == field_path_2); +} + +TEST(FieldPathTest, EqualsComparesFieldSelectionAndReturnsFalse) { + FieldPath field_path_1("resource.name"); + FieldPath field_path_2("resource.type"); + EXPECT_FALSE(field_path_1 == field_path_2); +} + +TEST(FieldPathTest, LessThanComparesFieldSelectionAndReturnsTrue) { + FieldPath field_path_1("resource.name"); + FieldPath field_path_2("resource.type"); + EXPECT_TRUE(field_path_1 < field_path_2); +} + +TEST(FieldPathTest, LessThanComparesFieldSelectionAndReturnsFalse) { + FieldPath field_path_1("resource.name"); + FieldPath field_path_2("resource.name"); + EXPECT_FALSE(field_path_1 < field_path_2); +} + +} // namespace +} // namespace cel::checker_internal diff --git a/checker/internal/proto_type_mask.cc b/checker/internal/proto_type_mask.cc new file mode 100644 index 000000000..3c97ae950 --- /dev/null +++ b/checker/internal/proto_type_mask.cc @@ -0,0 +1,38 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/proto_type_mask.h" + +#include +#include + +#include "absl/strings/str_join.h" +#include "absl/strings/substitute.h" +#include "checker/internal/field_path.h" + +namespace cel::checker_internal { + +std::string ProtoTypeMask::DebugString() const { + // Represent each FieldPath by its path because it is easiest to read. + std::vector paths; + paths.reserve(field_paths_.size()); + for (const FieldPath& field_path : field_paths_) { + paths.emplace_back(field_path.GetPath()); + } + return absl::Substitute( + "ProtoTypeMask { type name: '$0', field paths: { '$1' } }", type_name_, + absl::StrJoin(paths, "', '")); +} + +} // namespace cel::checker_internal diff --git a/checker/internal/proto_type_mask.h b/checker/internal/proto_type_mask.h new file mode 100644 index 000000000..b861eece2 --- /dev/null +++ b/checker/internal/proto_type_mask.h @@ -0,0 +1,66 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_CHECKER_PROTO_TYPE_MASK_H_ +#define THIRD_PARTY_CEL_CPP_CHECKER_PROTO_TYPE_MASK_H_ + +#include +#include +#include + +#include "absl/container/btree_set.h" +#include "absl/strings/string_view.h" +#include "checker/internal/field_path.h" + +namespace cel::checker_internal { + +// Represents the fraction of a protobuf type's object graph that should be +// visible within CEL expressions. +class ProtoTypeMask { + public: + explicit ProtoTypeMask(std::string type_name, + const std::set& field_paths) + : type_name_(std::move(type_name)) { + for (const std::string& field_path : field_paths) { + field_paths_.insert(FieldPath(field_path)); + } + }; + + absl::string_view GetTypeName() const { return type_name_; } + + const absl::btree_set& GetFieldPaths() const { + return field_paths_; + } + + template + friend void AbslStringify(Sink& sink, const ProtoTypeMask& proto_type_mask) { + absl::Format(&sink, "%v", proto_type_mask.DebugString()); + } + + private: + std::string DebugString() const; + + // A type's full name. For example: "google.rpc.context.AttributeContext". + std::string type_name_; + // A representation of a FieldMask, which is a set of field paths. + // A FieldMask contains one or more paths which contain identifier characters + // that have been dot delimited, e.g. resource.name, request.auth.claims. + // For each path, all descendent fields after the last element in the path are + // visible. An empty set means all fields are hidden. + absl::btree_set field_paths_; +}; + +} // namespace cel::checker_internal + +#endif // THIRD_PARTY_CEL_CPP_CHECKER_PROTO_TYPE_MASK_H_ diff --git a/checker/internal/proto_type_mask_registry.cc b/checker/internal/proto_type_mask_registry.cc new file mode 100644 index 000000000..53a4b0028 --- /dev/null +++ b/checker/internal/proto_type_mask_registry.cc @@ -0,0 +1,208 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/proto_type_mask_registry.h" + +#include +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/btree_set.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "checker/internal/field_path.h" +#include "checker/internal/proto_type_mask.h" +#include "common/type.h" +#include "internal/status_macros.h" +#include "google/protobuf/descriptor.h" + +namespace cel::checker_internal { +namespace { + +using ::google::protobuf::Descriptor; +using ::google::protobuf::DescriptorPool; +using ::google::protobuf::FieldDescriptor; +using TypeMap = + absl::flat_hash_map>; + +absl::StatusOr FindMessage( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + absl::string_view type_name) { + const Descriptor* descriptor = + descriptor_pool->FindMessageTypeByName(type_name); + if (descriptor == nullptr) { + return absl::InvalidArgumentError( + absl::Substitute("type '$0' not found", type_name)); + } + return descriptor; +} + +absl::StatusOr FindField(const Descriptor* descriptor, + absl::string_view field_name) { + const FieldDescriptor* field_descriptor = + descriptor->FindFieldByName(field_name); + if (field_descriptor == nullptr) { + return absl::InvalidArgumentError( + absl::Substitute("could not select field '$0' from type '$1'", + field_name, descriptor->full_name())); + } + return field_descriptor; +} + +absl::StatusOr GetMessage( + const FieldDescriptor* field_descriptor) { + cel::MessageTypeField field(field_descriptor); + cel::Type type = field.GetType(); + absl::optional message_type = type.AsMessage(); + if (!message_type.has_value()) { + return absl::InvalidArgumentError(absl::Substitute( + "field '$0' is not a message type", field_descriptor->name())); + } + return &(*message_type.value()); +} + +absl::Status AddAllHiddenFields(TypeMap& types_and_visible_fields, + absl::string_view type_name) { + auto result = types_and_visible_fields.find(type_name); + if (result != types_and_visible_fields.end()) { + if (!result->second.empty()) { + return absl::InvalidArgumentError( + absl::Substitute("cannot insert a proto type mask with all hidden " + "fields when type '$0' has already been inserted " + "with a proto type mask with a visible field", + type_name)); + } + return absl::OkStatus(); + } + types_and_visible_fields.insert({std::string(type_name), {}}); + return absl::OkStatus(); +} + +absl::Status AddVisibleField(TypeMap& types_and_visible_fields, + absl::string_view type_name, + absl::string_view field_name) { + auto result = types_and_visible_fields.find(type_name); + if (result != types_and_visible_fields.end()) { + if (result->second.empty()) { + return absl::InvalidArgumentError(absl::Substitute( + "cannot insert a proto type mask with visible " + "field '$0' when type '$1' has already been inserted " + "with a proto type mask with all hidden fields", + field_name, type_name)); + } + result->second.insert(std::string(field_name)); + return absl::OkStatus(); + } + types_and_visible_fields.insert( + {std::string(type_name), {std::string(field_name)}}); + return absl::OkStatus(); +} + +absl::StatusOr ComputeVisibleFieldsMap( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + const std::vector& proto_type_masks) { + TypeMap types_and_visible_fields; + for (const ProtoTypeMask& proto_type_mask : proto_type_masks) { + absl::string_view type_name = proto_type_mask.GetTypeName(); + CEL_ASSIGN_OR_RETURN(const Descriptor* descriptor, + FindMessage(descriptor_pool, type_name)); + const absl::btree_set field_paths = + proto_type_mask.GetFieldPaths(); + if (field_paths.empty()) { + CEL_RETURN_IF_ERROR( + AddAllHiddenFields(types_and_visible_fields, type_name)); + } + for (const FieldPath& field_path : field_paths) { + const Descriptor* target_descriptor = descriptor; + absl::Span field_selection = + field_path.GetFieldSelection(); + for (auto iterator = field_selection.begin(); + iterator != field_selection.end(); ++iterator) { + CEL_ASSIGN_OR_RETURN(const FieldDescriptor* field_descriptor, + FindField(target_descriptor, *iterator)); + CEL_RETURN_IF_ERROR(AddVisibleField(types_and_visible_fields, + target_descriptor->full_name(), + *iterator)); + if (std::next(iterator) != field_selection.end()) { + CEL_ASSIGN_OR_RETURN(target_descriptor, GetMessage(field_descriptor)); + } + } + } + } + return types_and_visible_fields; +} + +} // namespace + +absl::StatusOr> GetFieldNames( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + const ProtoTypeMask& proto_type_mask) { + CEL_ASSIGN_OR_RETURN( + const Descriptor* descriptor, + FindMessage(descriptor_pool, proto_type_mask.GetTypeName())); + std::set field_names; + for (const FieldPath& field_path : proto_type_mask.GetFieldPaths()) { + std::string field_name = field_path.GetFieldName(); + CEL_ASSIGN_OR_RETURN(const FieldDescriptor* field_descriptor, + FindField(descriptor, field_name)); + field_names.insert(field_descriptor->name()); + } + return field_names; +} + +absl::StatusOr> +ProtoTypeMaskRegistry::Create( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + const std::vector& proto_type_masks) { + CEL_ASSIGN_OR_RETURN( + auto types_and_visible_fields, + ComputeVisibleFieldsMap(descriptor_pool, proto_type_masks)); + std::shared_ptr proto_type_mask_registry = + absl::WrapUnique(new ProtoTypeMaskRegistry(types_and_visible_fields)); + return proto_type_mask_registry; +} + +bool ProtoTypeMaskRegistry::FieldIsVisible(absl::string_view type_name, + absl::string_view field_name) { + auto iterator = types_and_visible_fields_.find(type_name); + if (iterator != types_and_visible_fields_.end() && + !iterator->second.contains(field_name)) { + return false; + } + return true; +} + +std::string ProtoTypeMaskRegistry::DebugString() const { + std::string output = "ProtoTypeMaskRegistry { "; + for (auto& element : types_and_visible_fields_) { + absl::StrAppend(&output, "{type: '", element.first, "', visible_fields: '", + absl::StrJoin(element.second, "', '"), "'} "); + } + absl::StrAppend(&output, "}"); + return output; +} + +} // namespace cel::checker_internal diff --git a/checker/internal/proto_type_mask_registry.h b/checker/internal/proto_type_mask_registry.h new file mode 100644 index 000000000..97a0b279d --- /dev/null +++ b/checker/internal/proto_type_mask_registry.h @@ -0,0 +1,85 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_CHECKER_PROTO_TYPE_MASK_REGISTRY_H_ +#define THIRD_PARTY_CEL_CPP_CHECKER_PROTO_TYPE_MASK_REGISTRY_H_ + +#include +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "checker/internal/proto_type_mask.h" +#include "google/protobuf/descriptor.h" + +namespace cel::checker_internal { + +// Returns a set of field names for the input proto type mask. +// The set includes the first field name from each field path. +absl::StatusOr> GetFieldNames( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + const ProtoTypeMask& proto_type_mask); + +// Stores information related to ProtoTypeMasks. Visibility is defined per type, +// meaning that all messages of a type have the same visible fields. +class ProtoTypeMaskRegistry { + public: + // Processes the input proto type masks to create + static absl::StatusOr> Create( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + const std::vector& proto_type_masks); + + const absl::flat_hash_map>& + GetTypesAndVisibleFields() const { + return types_and_visible_fields_; + } + + // Returns true when the field name is visible. A field is visible if: + // 1. The type name is not a key in the map. + // 2. The type name is a key in the map and the field name is in the set of + // field names that are visible for the type. + bool FieldIsVisible(absl::string_view type_name, + absl::string_view field_name); + + template + friend void AbslStringify( + Sink& sink, + const std::shared_ptr& proto_type_mask_registry) { + absl::Format(&sink, "%v", proto_type_mask_registry->DebugString()); + } + + private: + explicit ProtoTypeMaskRegistry( + absl::flat_hash_map> + types_and_visible_fields) + : types_and_visible_fields_(std::move(types_and_visible_fields)) {} + + std::string DebugString() const; + + // Map of types that have a field mask where the keys are + // fully qualified type names and the values are the set of field names that + // are visible for the type. + absl::flat_hash_map> + types_and_visible_fields_; +}; + +} // namespace cel::checker_internal + +#endif // THIRD_PARTY_CEL_CPP_CHECKER_PROTO_TYPE_MASK_REGISTRY_H_ diff --git a/checker/internal/proto_type_mask_registry_test.cc b/checker/internal/proto_type_mask_registry_test.cc new file mode 100644 index 000000000..6afaa0214 --- /dev/null +++ b/checker/internal/proto_type_mask_registry_test.cc @@ -0,0 +1,415 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/proto_type_mask_registry.h" + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "checker/internal/proto_type_mask.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" + +namespace cel::checker_internal { +namespace { + +using ::absl_testing::StatusIs; +using ::cel::internal::GetSharedTestingDescriptorPool; +using ::testing::AllOf; +using ::testing::HasSubstr; +using ::testing::IsEmpty; +using ::testing::Pair; +using ::testing::UnorderedElementsAre; + +using TypeMap = + absl::flat_hash_map>; + +TEST(ProtoTypeMaskRegistryTest, GetFieldNamesWithEmptyTypeReturnsError) { + ProtoTypeMask proto_type_mask("", {}); + EXPECT_THAT( + GetFieldNames(GetSharedTestingDescriptorPool().get(), proto_type_mask), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("type '' not found"))); +} + +TEST(ProtoTypeMaskRegistryTest, GetFieldNamesWithUnknownTypeReturnsError) { + ProtoTypeMask proto_type_mask("com.example.UnknownType", {}); + EXPECT_THAT( + GetFieldNames(GetSharedTestingDescriptorPool().get(), proto_type_mask), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("type 'com.example.UnknownType' not found"))); +} + +TEST(ProtoTypeMaskRegistryTest, + GetFieldNamesWithEmptySetFieldPathSucceedsAndReturnsEmptySet) { + ProtoTypeMask proto_type_mask("cel.expr.conformance.proto3.TestAllTypes", {}); + ASSERT_OK_AND_ASSIGN( + std::set field_names, + GetFieldNames(GetSharedTestingDescriptorPool().get(), proto_type_mask)); + EXPECT_THAT(field_names, IsEmpty()); +} + +TEST(ProtoTypeMaskRegistryTest, GetFieldNamesWithEmptyFieldPathReturnsError) { + ProtoTypeMask proto_type_mask("cel.expr.conformance.proto3.TestAllTypes", + {""}); + EXPECT_THAT( + GetFieldNames(GetSharedTestingDescriptorPool().get(), proto_type_mask), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("could not select field '' from type " + "'cel.expr.conformance.proto3.TestAllTypes'"))); +} + +TEST(ProtoTypeMaskRegistryTest, GetFieldNamesWithUnknownFieldReturnsError) { + ProtoTypeMask proto_type_mask("cel.expr.conformance.proto3.TestAllTypes", + {"unknown_field"}); + EXPECT_THAT( + GetFieldNames(GetSharedTestingDescriptorPool().get(), proto_type_mask), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("could not select field 'unknown_field' from type " + "'cel.expr.conformance.proto3.TestAllTypes'"))); +} + +TEST(ProtoTypeMaskRegistryTest, + GetFieldNamesWithListOfFieldPathsSucceedsAndReturnsFieldNames) { + ProtoTypeMask proto_type_mask( + "cel.expr.conformance.proto3.NestedTestAllTypes", + {"payload.standalone_message.bb", "payload.single_int32", + "child.any_field_name"}); + ASSERT_OK_AND_ASSIGN( + std::set field_names, + GetFieldNames(GetSharedTestingDescriptorPool().get(), proto_type_mask)); + EXPECT_THAT(field_names, UnorderedElementsAre("payload", "child")); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithEmptyInputSucceedsAndAllFieldsAreVisible) { + std::vector proto_type_masks = {}; + ASSERT_OK_AND_ASSIGN( + std::shared_ptr proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT(proto_type_mask_registry->GetTypesAndVisibleFields(), IsEmpty()); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible("any_type_name", + "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, CreateWithEmptyTypeReturnsError) { + std::vector proto_type_masks = {ProtoTypeMask("", {})}; + EXPECT_THAT(ProtoTypeMaskRegistry::Create( + GetSharedTestingDescriptorPool().get(), proto_type_masks), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("type '' not found"))); +} + +TEST(ProtoTypeMaskRegistryTest, CreateWithUnknownTypeReturnsError) { + std::vector proto_type_masks = { + ProtoTypeMask("com.example.UnknownType", {})}; + EXPECT_THAT(ProtoTypeMaskRegistry::Create( + GetSharedTestingDescriptorPool().get(), proto_type_masks), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("type 'com.example.UnknownType' not found"))); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithEmptySetFieldPathSucceedsAndFieldsAreHidden) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", {})}; + ASSERT_OK_AND_ASSIGN( + std::shared_ptr proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT(proto_type_mask_registry->GetTypesAndVisibleFields(), + UnorderedElementsAre( + Pair("cel.expr.conformance.proto3.TestAllTypes", IsEmpty()))); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible("any_type_name", + "any_field_name")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithDuplicateEmptySetFieldPathSucceedsAndFieldsAreHidden) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", {}), + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", {})}; + ASSERT_OK_AND_ASSIGN( + std::shared_ptr proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT(proto_type_mask_registry->GetTypesAndVisibleFields(), + UnorderedElementsAre( + Pair("cel.expr.conformance.proto3.TestAllTypes", IsEmpty()))); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible("any_type_name", + "any_field_name")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, CreateWithEmptyFieldPathReturnsError) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", {""})}; + EXPECT_THAT( + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("could not select field '' from type " + "'cel.expr.conformance.proto3.TestAllTypes'"))); +} + +TEST(ProtoTypeMaskRegistryTest, CreateWithUnknownFieldReturnsError) { + std::vector proto_type_masks = {ProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"unknown_field"})}; + EXPECT_THAT( + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("could not select field 'unknown_field' from type " + "'cel.expr.conformance.proto3.TestAllTypes'"))); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithDepthOneNonMessageFieldsSucceedsAndFieldsAreVisible) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", + {"single_int32", "single_any", "single_timestamp"})}; + ASSERT_OK_AND_ASSIGN( + std::shared_ptr proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT(proto_type_mask_registry->GetTypesAndVisibleFields(), + UnorderedElementsAre( + Pair("cel.expr.conformance.proto3.TestAllTypes", + UnorderedElementsAre("single_int32", "single_any", + "single_timestamp")))); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible("any_type_name", + "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "single_int32")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "single_any")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "single_timestamp")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, CreateWithDepthTwoNonMessageFieldReturnsError) { + std::vector proto_type_masks; + proto_type_masks.push_back( + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", + {"single_int32.any_field_name"})); + EXPECT_THAT( + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("field 'single_int32' is not a message type"))); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithDepthOneMessageFieldSucceedsAndFieldsAreVisible) { + std::vector proto_type_masks = {ProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"standalone_message"})}; + ASSERT_OK_AND_ASSIGN( + std::shared_ptr proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT( + proto_type_mask_registry->GetTypesAndVisibleFields(), + UnorderedElementsAre(Pair("cel.expr.conformance.proto3.TestAllTypes", + UnorderedElementsAre("standalone_message")))); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible("any_type_name", + "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "standalone_message")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithDepthTwoMessageFieldSucceedsAndFieldsAreVisible) { + std::vector proto_type_masks = {ProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"standalone_message.bb"})}; + ASSERT_OK_AND_ASSIGN( + std::shared_ptr proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT(proto_type_mask_registry->GetTypesAndVisibleFields(), + UnorderedElementsAre( + Pair("cel.expr.conformance.proto3.TestAllTypes", + UnorderedElementsAre("standalone_message")), + Pair("cel.expr.conformance.proto3.TestAllTypes.NestedMessage", + UnorderedElementsAre("bb")))); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible("any_type_name", + "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "standalone_message")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage", "bb")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage", + "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, CreateWithDepthTwoUnknownFieldReturnsError) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", + {"standalone_message.unknown_field"})}; + EXPECT_THAT( + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks), + StatusIs( + absl::StatusCode::kInvalidArgument, + HasSubstr( + "could not select field 'unknown_field' from type " + "'cel.expr.conformance.proto3.TestAllTypes.NestedMessage'"))); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithDepthThreeMessageFieldSucceedsAndFieldsAreVisible) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.NestedTestAllTypes", + {"payload.standalone_message.bb"})}; + ASSERT_OK_AND_ASSIGN( + std::shared_ptr proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT(proto_type_mask_registry->GetTypesAndVisibleFields(), + UnorderedElementsAre( + Pair("cel.expr.conformance.proto3.NestedTestAllTypes", + UnorderedElementsAre("payload")), + Pair("cel.expr.conformance.proto3.TestAllTypes", + UnorderedElementsAre("standalone_message")), + Pair("cel.expr.conformance.proto3.TestAllTypes.NestedMessage", + UnorderedElementsAre("bb")))); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible("any_type_name", + "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.NestedTestAllTypes", "payload")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.NestedTestAllTypes", "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "standalone_message")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage", "bb")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage", + "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithListOfFieldPathsSucceedsAndFieldsAreVisible) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.NestedTestAllTypes", + {"payload.standalone_message.bb", "payload.single_int32"})}; + ASSERT_OK_AND_ASSIGN( + std::shared_ptr proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT( + proto_type_mask_registry->GetTypesAndVisibleFields(), + UnorderedElementsAre( + Pair("cel.expr.conformance.proto3.NestedTestAllTypes", + UnorderedElementsAre("payload")), + Pair("cel.expr.conformance.proto3.TestAllTypes", + UnorderedElementsAre("standalone_message", "single_int32")), + Pair("cel.expr.conformance.proto3.TestAllTypes.NestedMessage", + UnorderedElementsAre("bb")))); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible("any_type_name", + "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.NestedTestAllTypes", "payload")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.NestedTestAllTypes", "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "standalone_message")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "single_int32")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage", "bb")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage", + "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateAddVisibleFieldThenAllHiddenFieldsReturnsError) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", + {"standalone_message.bb"}), + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes.NestedMessage", + {})}; + EXPECT_THAT( + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks), + StatusIs( + absl::StatusCode::kInvalidArgument, + HasSubstr( + "cannot insert a proto type mask with all hidden fields when " + "type 'cel.expr.conformance.proto3.TestAllTypes.NestedMessage' " + "has already been inserted with a proto type mask with a visible " + "field"))); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateAddAllHiddenThenVisibleFieldReturnsError) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes.NestedMessage", + {}), + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", + {"standalone_message.bb"})}; + EXPECT_THAT( + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks), + StatusIs( + absl::StatusCode::kInvalidArgument, + HasSubstr( + "cannot insert a proto type mask with visible field 'bb' when " + "type 'cel.expr.conformance.proto3.TestAllTypes.NestedMessage' " + "has already been inserted with a proto type mask with all " + "hidden fields"))); +} + +TEST(ProtoTypeMaskRegistryTest, AbslStringifyPrintsTypesAndVisibleFieldsMap) { + std::vector proto_type_masks = {ProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"standalone_message.bb"})}; + ASSERT_OK_AND_ASSIGN( + std::shared_ptr proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT( + absl::StrCat(proto_type_mask_registry), + AllOf(HasSubstr("ProtoTypeMaskRegistry {"), + HasSubstr("{type: 'cel.expr.conformance.proto3.TestAllTypes', " + "visible_fields: 'standalone_message'}"), + HasSubstr("{type: " + "'cel.expr.conformance.proto3.TestAllTypes.NestedMessage'" + ", visible_fields: 'bb'}"))); +} + +} // namespace +} // namespace cel::checker_internal diff --git a/checker/internal/proto_type_mask_test.cc b/checker/internal/proto_type_mask_test.cc new file mode 100644 index 000000000..ec46e4ced --- /dev/null +++ b/checker/internal/proto_type_mask_test.cc @@ -0,0 +1,59 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/proto_type_mask.h" + +#include +#include + +#include "absl/strings/str_cat.h" +#include "checker/internal/field_path.h" +#include "internal/testing.h" + +namespace cel::checker_internal { +namespace { + +using ::testing::HasSubstr; +using ::testing::IsEmpty; +using ::testing::UnorderedElementsAre; + +TEST(ProtoTypeMaskTest, EmptyTypeNameAndEmptyFieldPathsSucceeds) { + std::string type_name = ""; + std::set field_paths; + ProtoTypeMask proto_type_mask(type_name, field_paths); + EXPECT_EQ(proto_type_mask.GetTypeName(), ""); + EXPECT_THAT(proto_type_mask.GetFieldPaths(), IsEmpty()); +} + +TEST(ProtoTypeMaskTest, NotEmptyTypeNameAndNotEmptyFieldPathsSucceeds) { + std::string type_name = "google.type.Expr"; + std::set field_paths = {"resource.name", "resource.type"}; + ProtoTypeMask proto_type_mask(type_name, field_paths); + EXPECT_EQ(proto_type_mask.GetTypeName(), "google.type.Expr"); + EXPECT_THAT(proto_type_mask.GetFieldPaths(), + UnorderedElementsAre(FieldPath("resource.name"), + FieldPath("resource.type"))); +} + +TEST(ProtoTypeMaskTest, AbslStringifyPrintsTypeNameAndFieldPaths) { + std::string type_name = "google.type.Expr"; + std::set field_paths = {"resource.name", "resource.type"}; + ProtoTypeMask proto_type_mask(type_name, field_paths); + EXPECT_THAT(absl::StrCat(proto_type_mask), + HasSubstr("ProtoTypeMask { type name: 'google.type.Expr', field " + "paths: { 'resource.name', 'resource.type' } }")); +} + +} // namespace +} // namespace cel::checker_internal diff --git a/checker/internal/type_check_env.cc b/checker/internal/type_check_env.cc index c080326cb..c677b1be4 100644 --- a/checker/internal/type_check_env.cc +++ b/checker/internal/type_check_env.cc @@ -96,6 +96,9 @@ absl::StatusOr> TypeCheckEnv::LookupTypeConstant( absl::StatusOr> TypeCheckEnv::LookupStructField( absl::string_view type_name, absl::string_view field_name) const { + if (!proto_type_mask_registry_->FieldIsVisible(type_name, field_name)) { + return absl::nullopt; + } // Check the type providers in registration order. // Note: this doesn't allow for shadowing a type with a subset type of the // same name -- the later type provider will still be considered when diff --git a/checker/internal/type_check_env.h b/checker/internal/type_check_env.h index 15f8ecc4d..00fea0ba3 100644 --- a/checker/internal/type_check_env.h +++ b/checker/internal/type_check_env.h @@ -25,16 +25,20 @@ #include "absl/container/flat_hash_map.h" #include "absl/log/absl_check.h" #include "absl/memory/memory.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "checker/internal/descriptor_pool_type_introspector.h" +#include "checker/internal/proto_type_mask.h" +#include "checker/internal/proto_type_mask_registry.h" #include "common/constant.h" #include "common/container.h" #include "common/decl.h" #include "common/type.h" #include "common/type_introspector.h" +#include "internal/status_macros.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" @@ -154,6 +158,14 @@ class TypeCheckEnv { variables_[decl.name()] = std::move(decl); } + absl::Status CreateProtoTypeMaskRegistry( + const std::vector& proto_type_masks) { + CEL_ASSIGN_OR_RETURN(proto_type_mask_registry_, + ProtoTypeMaskRegistry::Create(descriptor_pool_.get(), + proto_type_masks)); + return absl::OkStatus(); + } + const absl::flat_hash_map& functions() const { return functions_; } @@ -224,6 +236,8 @@ class TypeCheckEnv { absl::flat_hash_map variables_; absl::flat_hash_map functions_; + std::shared_ptr proto_type_mask_registry_; + // Type providers for custom types. std::vector> type_providers_; diff --git a/checker/internal/type_checker_builder_impl.cc b/checker/internal/type_checker_builder_impl.cc index 94a05602e..4fbb2c948 100644 --- a/checker/internal/type_checker_builder_impl.cc +++ b/checker/internal/type_checker_builder_impl.cc @@ -16,6 +16,7 @@ #include #include +#include #include #include #include @@ -24,12 +25,15 @@ #include "absl/base/nullability.h" #include "absl/cleanup/cleanup.h" #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/log/absl_log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "checker/internal/proto_type_mask.h" +#include "checker/internal/proto_type_mask_registry.h" #include "checker/internal/type_check_env.h" #include "checker/internal/type_checker_impl.h" #include "checker/type_checker.h" @@ -86,10 +90,18 @@ absl::Status CheckStdMacroOverlap(const FunctionDecl& decl) { } absl::Status AddWellKnownContextDeclarationVariables( - const google::protobuf::Descriptor* absl_nonnull descriptor, TypeCheckEnv& env, - bool use_json_name) { + const google::protobuf::Descriptor* absl_nonnull descriptor, + const absl::flat_hash_map>& + context_type_fields, + TypeCheckEnv& env, bool use_json_name) { for (int i = 0; i < descriptor->field_count(); ++i) { const google::protobuf::FieldDescriptor* field = descriptor->field(i); + // Skip fields that are hidden because of a proto type mask. + auto map_iterator = context_type_fields.find(descriptor->full_name()); + if (map_iterator != context_type_fields.end() && + !map_iterator->second.contains(field->name())) { + continue; + } Type type = MessageTypeField(field).GetType(); if (type.IsEnum()) { type = IntType(); @@ -109,11 +121,14 @@ absl::Status AddWellKnownContextDeclarationVariables( } absl::Status AddContextDeclarationVariables( - const google::protobuf::Descriptor* absl_nonnull descriptor, TypeCheckEnv& env) { + const google::protobuf::Descriptor* absl_nonnull descriptor, + const absl::flat_hash_map>& + context_type_fields, + TypeCheckEnv& env) { const bool use_json_name = env.proto_type_introspector().use_json_name(); if (IsWellKnownMessageType(descriptor)) { - return AddWellKnownContextDeclarationVariables(descriptor, env, - use_json_name); + return AddWellKnownContextDeclarationVariables( + descriptor, context_type_fields, env, use_json_name); } CEL_ASSIGN_OR_RETURN(auto fields, env.proto_type_introspector().ListFieldsForStructType( @@ -317,7 +332,8 @@ absl::Status TypeCheckerBuilderImpl::ApplyConfig( } for (const google::protobuf::Descriptor* context_type : config.context_types) { - CEL_RETURN_IF_ERROR(AddContextDeclarationVariables(context_type, env)); + CEL_RETURN_IF_ERROR(AddContextDeclarationVariables( + context_type, config.context_type_fields, env)); } for (VariableDeclRecord& var : config.variables) { @@ -339,6 +355,8 @@ absl::Status TypeCheckerBuilderImpl::ApplyConfig( } } + CEL_RETURN_IF_ERROR(env.CreateProtoTypeMaskRegistry(config.proto_type_masks)); + return absl::OkStatus(); } @@ -427,6 +445,15 @@ absl::Status TypeCheckerBuilderImpl::AddVariable(const VariableDecl& decl) { return absl::OkStatus(); } +absl::Status TypeCheckerBuilderImpl::AddVariableWithTypeMask( + const VariableDecl& decl, std::set field_paths) { + ProtoTypeMask proto_type_mask(std::string(decl.type().name()), field_paths); + target_config_->proto_type_masks.push_back(proto_type_mask); + + CEL_RETURN_IF_ERROR(AddVariable(decl)); + return absl::OkStatus(); +} + absl::Status TypeCheckerBuilderImpl::AddOrReplaceVariable( const VariableDecl& decl) { CEL_RETURN_IF_ERROR( @@ -462,6 +489,22 @@ absl::Status TypeCheckerBuilderImpl::AddContextDeclaration( return absl::OkStatus(); } +absl::Status TypeCheckerBuilderImpl::AddContextDeclarationWithTypeMask( + absl::string_view type, std::set field_paths) { + if (field_paths.empty()) { + return absl::InvalidArgumentError("field paths cannot be the empty set"); + } + + ProtoTypeMask proto_type_mask(std::string(type), field_paths); + target_config_->proto_type_masks.push_back(proto_type_mask); + + CEL_RETURN_IF_ERROR(AddContextDeclaration(type)); + CEL_ASSIGN_OR_RETURN(std::set field_names, + GetFieldNames(descriptor_pool_.get(), proto_type_mask)); + target_config_->context_type_fields.insert({type, std::move(field_names)}); + return absl::OkStatus(); +} + absl::Status TypeCheckerBuilderImpl::AddFunction(const FunctionDecl& decl) { CEL_RETURN_IF_ERROR( ValidateFunctionDecl(decl, options_.enable_type_parameter_name_validation, diff --git a/checker/internal/type_checker_builder_impl.h b/checker/internal/type_checker_builder_impl.h index 646a5d16f..d5a5f09ec 100644 --- a/checker/internal/type_checker_builder_impl.h +++ b/checker/internal/type_checker_builder_impl.h @@ -16,6 +16,7 @@ #define THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_TYPE_CHECKER_BUILDER_IMPL_H_ #include +#include #include #include #include @@ -28,6 +29,7 @@ #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "checker/checker_options.h" +#include "checker/internal/proto_type_mask.h" #include "checker/internal/type_check_env.h" #include "checker/type_checker.h" #include "checker/type_checker_builder.h" @@ -74,8 +76,12 @@ class TypeCheckerBuilderImpl : public TypeCheckerBuilder { absl::Status AddLibrarySubset(TypeCheckerSubset subset) override; absl::Status AddVariable(const VariableDecl& decl) override; + absl::Status AddVariableWithTypeMask( + const VariableDecl& decl, std::set field_paths) override; absl::Status AddOrReplaceVariable(const VariableDecl& decl) override; absl::Status AddContextDeclaration(absl::string_view type) override; + absl::Status AddContextDeclarationWithTypeMask( + absl::string_view type, std::set field_paths) override; absl::Status AddFunction(const FunctionDecl& decl) override; absl::Status MergeFunction(const FunctionDecl& decl) override; @@ -130,6 +136,11 @@ class TypeCheckerBuilderImpl : public TypeCheckerBuilder { std::vector functions; std::vector> type_providers; std::vector context_types; + // Maps context type names to fields names to add as variables. + // Only includes context types that are defined with proto type masks. + absl::flat_hash_map> + context_type_fields; + std::vector proto_type_masks; }; absl::Status BuildLibraryConfig(const CheckerLibrary& library, diff --git a/checker/type_checker_builder.h b/checker/type_checker_builder.h index 5dd1f5256..c73c4de1d 100644 --- a/checker/type_checker_builder.h +++ b/checker/type_checker_builder.h @@ -16,6 +16,7 @@ #define THIRD_PARTY_CEL_CPP_CHECKER_TYPE_CHECKER_BUILDER_H_ #include +#include #include #include "absl/base/nullability.h" @@ -84,6 +85,20 @@ class TypeCheckerBuilder { // with the resulting type checker. virtual absl::Status AddVariable(const VariableDecl& decl) = 0; + // Adds a variable declaration that may be referenced in expressions checked + // with the resulting type checker. + // + // Adds a ProtoTypeMask (similar to a FieldMask) from the field paths that + // defines the visible fields for each type. A set of field paths that is + // the empty set defines an opaque type where the variable can be used in + // functions, but the fields aren't directly accessible in the CEL expression. + // + // It is an error if the variable's type is not a message type. + // + // Note: only protobuf backed struct types are supported at this time. + virtual absl::Status AddVariableWithTypeMask( + const VariableDecl& decl, std::set field_paths) = 0; + // Adds a variable declaration that may be referenced in expressions checked // with the resulting type checker. // @@ -102,6 +117,24 @@ class TypeCheckerBuilder { // Note: only protobuf backed struct types are supported at this time. virtual absl::Status AddContextDeclaration(absl::string_view type) = 0; + // Declares struct type by fully qualified name as a context declaration. + // + // Adds a ProtoTypeMask (similar to a FieldMask) from the field paths that + // defines the visible fields for each type. + // + // Context declarations are a way to declare a group of variables based on the + // definition of a struct type. Each top level field of the struct that is + // also the first field name in a field path is declared as an individual + // variable of the field type. + // + // It is an error if the type contains a field that overlaps with another + // declared variable. It is an error if the input field paths is the empty + // set. + // + // Note: only protobuf backed struct types are supported at this time. + virtual absl::Status AddContextDeclarationWithTypeMask( + absl::string_view type, std::set field_paths) = 0; + // Adds a function declaration that may be referenced in expressions checked // with the resulting TypeChecker. virtual absl::Status AddFunction(const FunctionDecl& decl) = 0;